[Mlir-commits] [mlir] c6390f1 - [mlir] Fix AsmPrinter for types with sub elements
Vladislav Vinogradov
llvmlistbot at llvm.org
Tue Oct 12 02:18:06 PDT 2021
Author: Vladislav Vinogradov
Date: 2021-10-12T12:08:16+03:00
New Revision: c6390f19f20c93687d6366f6e78d6e96f1f0a126
URL: https://github.com/llvm/llvm-project/commit/c6390f19f20c93687d6366f6e78d6e96f1f0a126
DIFF: https://github.com/llvm/llvm-project/commit/c6390f19f20c93687d6366f6e78d6e96f1f0a126.diff
LOG: [mlir] Fix AsmPrinter for types with sub elements
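Call `printType(subElemType)` instead of `os << subElemType` when printing the result types of function types and the element types of vector and ranked tensor types.
This allows type aliases to be used inside such complex types.
As a side effect, fix `test.int` parsing: `StringRef::compare_insensitive` returns 0 when the strings match, so the old signedness checks were inverted (e.g. `s` and `n` were parsed as `unsigned`); use `equals_insensitive` instead.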
Reviewed By: rriddle
Differential Revision: https://reviews.llvm.org/D111536
Added:
Modified:
mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/print-attr-type-aliases.mlir
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestTypes.cpp
mlir/test/mlir-tblgen/testdialect-typedefs.mlir
Removed:
################################################################################
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 3d9090daaf8b..67f135ff4bc1 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1921,7 +1921,7 @@ void AsmPrinter::Impl::printType(Type type) {
os << ") -> ";
ArrayRef<Type> results = funcTy.getResults();
if (results.size() == 1 && !results[0].isa<FunctionType>()) {
- os << results[0];
+ printType(results[0]);
} else {
os << '(';
interleaveComma(results, [&](Type ty) { printType(ty); });
@@ -1932,7 +1932,8 @@ void AsmPrinter::Impl::printType(Type type) {
os << "vector<";
for (int64_t dim : vectorTy.getShape())
os << dim << 'x';
- os << vectorTy.getElementType() << '>';
+ printType(vectorTy.getElementType());
+ os << '>';
})
.Case<RankedTensorType>([&](RankedTensorType tensorTy) {
os << "tensor<";
@@ -1943,7 +1944,7 @@ void AsmPrinter::Impl::printType(Type type) {
os << dim;
os << 'x';
}
- os << tensorTy.getElementType();
+ printType(tensorTy.getElementType());
// Only print the encoding attribute value if set.
if (tensorTy.getEncoding()) {
os << ", ";
diff --git a/mlir/test/IR/print-attr-type-aliases.mlir b/mlir/test/IR/print-attr-type-aliases.mlir
index 5bb408d77680..286f62c192b9 100644
--- a/mlir/test/IR/print-attr-type-aliases.mlir
+++ b/mlir/test/IR/print-attr-type-aliases.mlir
@@ -24,3 +24,7 @@
// CHECK-DAG: #test_encoding = "alias_test:tensor_encoding"
// CHECK-DAG: tensor<32xf32, #test_encoding>
"test.op"() : () -> tensor<32xf32, "alias_test:tensor_encoding">
+
+// CHECK-DAG: !test_ui8_ = type !test.int<unsigned, 8>
+// CHECK-DAG: tensor<32x!test_ui8_>
+"test.op"() : () -> tensor<32x!test.int<unsigned, 8>>
diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index a0f88452b3d1..658ed0957be3 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -88,6 +88,14 @@ struct TestOpAsmInterface : public OpAsmDialectInterface {
return AliasResult::FinalAlias;
}
}
+ if (auto intType = type.dyn_cast<TestIntegerType>()) {
+ if (intType.getSignedness() ==
+ TestIntegerType::SignednessSemantics::Unsigned &&
+ intType.getWidth() == 8) {
+ os << "test_ui8";
+ return AliasResult::FinalAlias;
+ }
+ }
return AliasResult::NoAlias;
}
diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp
index 9d618ce1aad6..7278f52a5865 100644
--- a/mlir/test/lib/Dialect/Test/TestTypes.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp
@@ -32,14 +32,13 @@ parseSignedness(DialectAsmParser &parser,
auto loc = parser.getCurrentLocation();
if (parser.parseKeyword(&signStr))
return failure();
- if (signStr.compare_insensitive("u") ||
- signStr.compare_insensitive("unsigned"))
+ if (signStr.equals_insensitive("u") || signStr.equals_insensitive("unsigned"))
result = TestIntegerType::SignednessSemantics::Unsigned;
- else if (signStr.compare_insensitive("s") ||
- signStr.compare_insensitive("signed"))
+ else if (signStr.equals_insensitive("s") ||
+ signStr.equals_insensitive("signed"))
result = TestIntegerType::SignednessSemantics::Signed;
- else if (signStr.compare_insensitive("n") ||
- signStr.compare_insensitive("none"))
+ else if (signStr.equals_insensitive("n") ||
+ signStr.equals_insensitive("none"))
result = TestIntegerType::SignednessSemantics::Signless;
else
return parser.emitError(loc, "expected signed, unsigned, or none");
diff --git a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
index c8500e47b695..783b4be704c7 100644
--- a/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
+++ b/mlir/test/mlir-tblgen/testdialect-typedefs.mlir
@@ -13,12 +13,12 @@ func @compoundA(%A : !test.cmpnd_a<1, !test.smpla, [5, 6]>)-> () {
return
}
-// CHECK: @testInt(%arg0: !test.int<unsigned, 8>, %arg1: !test.int<unsigned, 2>, %arg2: !test.int<unsigned, 1>)
+// CHECK: @testInt(%arg0: !test.int<signed, 8>, %arg1: !test.int<unsigned, 2>, %arg2: !test.int<none, 1>)
func @testInt(%A : !test.int<s, 8>, %B : !test.int<unsigned, 2>, %C : !test.int<n, 1>) {
return
}
-// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int<unsigned, 3>}>)
+// CHECK: @structTest(%arg0: !test.struct<{field1,!test.smpla},{field2,!test.int<none, 3>}>)
func @structTest (%A : !test.struct< {field1, !test.smpla}, {field2, !test.int<none, 3>} > ) {
return
}
More information about the Mlir-commits mailing list