[Mlir-commits] [mlir] 24aa4ef - [mlir] Print 0 element DenseElementsAttr as dense<> to fix parser bugs with expected shape.
River Riddle
llvmlistbot at llvm.org
Wed Jul 8 18:48:10 PDT 2020
Author: River Riddle
Date: 2020-07-08T18:44:23-07:00
New Revision: 24aa4efffd831a1125b4eb835e1911fa38f501d7
URL: https://github.com/llvm/llvm-project/commit/24aa4efffd831a1125b4eb835e1911fa38f501d7
DIFF: https://github.com/llvm/llvm-project/commit/24aa4efffd831a1125b4eb835e1911fa38f501d7.diff
LOG: [mlir] Print 0 element DenseElementsAttr as dense<> to fix parser bugs with expected shape.
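Depending on where the zero-sized dimension appears in the shape, the parser currently rejects .mlir output generated by the printer. For example, a `tensor<0x512x512xi32>` attribute was printed as `dense<[[[]]]>`, from which the parser infers the shape 1x1x0 rather than 0x512x512.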
Differential Revision: https://reviews.llvm.org/D83445
Added:
Modified:
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/test/IR/parser.mlir
Removed:
################################################################################
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 881f77f6004a..09135021a732 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1432,10 +1432,12 @@ void ModulePrinter::printAttribute(Attribute attr,
break;
}
os << "sparse<";
- printDenseIntOrFPElementsAttr(elementsAttr.getIndices(),
- /*allowHex=*/false);
- os << ", ";
- printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
+ DenseIntElementsAttr indices = elementsAttr.getIndices();
+ if (indices.getNumElements() != 0) {
+ printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
+ os << ", ";
+ printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
+ }
os << '>';
break;
}
@@ -1476,20 +1478,15 @@ printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
// Special case for degenerate tensors.
auto numElements = type.getNumElements();
- int64_t rank = type.getRank();
- if (numElements == 0) {
- for (int i = 0; i < rank; ++i)
- os << '[';
- for (int i = 0; i < rank; ++i)
- os << ']';
+ if (numElements == 0)
return;
- }
// We use a mixed-radix counter to iterate through the shape. When we bump a
// non-least-significant digit, we emit a close bracket. When we next emit an
// element we re-open all closed brackets.
// The mixed-radix counter, with radices in 'shape'.
+ int64_t rank = type.getRank();
SmallVector<unsigned, 4> counter(rank, 0);
// The number of brackets that have been opened and not closed.
unsigned openBrackets = 0;
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 65613a149ae9..e2860b115231 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -758,13 +758,13 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
if (parseToken(Token::less, "expected '<' after 'dense'"))
return nullptr;
- // Parse the literal data.
+ // Parse the literal data if necessary.
TensorLiteralParser literalParser(*this);
- if (literalParser.parse(/*allowHex=*/true))
- return nullptr;
-
- if (parseToken(Token::greater, "expected '>'"))
- return nullptr;
+ if (!consumeIf(Token::greater)) {
+ if (literalParser.parse(/*allowHex=*/true) ||
+ parseToken(Token::greater, "expected '>'"))
+ return nullptr;
+ }
auto typeLoc = getToken().getLoc();
auto type = parseElementsLiteralType(attrType);
@@ -841,6 +841,25 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
if (parseToken(Token::less, "Expected '<' after 'sparse'"))
return nullptr;
+ // Check for the case where all elements are sparse. The indices are
+ // represented by a 2-dimensional shape where the second dimension is the rank
+ // of the type.
+ Type indiceEltType = builder.getIntegerType(64);
+ if (consumeIf(Token::greater)) {
+ ShapedType type = parseElementsLiteralType(attrType);
+ if (!type)
+ return nullptr;
+
+ // Construct the sparse elements attr using zero element indice/value
+ // attributes.
+ ShapedType indicesType =
+ RankedTensorType::get({0, type.getRank()}, indiceEltType);
+ ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
+ return SparseElementsAttr::get(
+ type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
+ DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
+ }
+
/// Parse the indices. We don't allow hex values here as we may need to use
/// the inferred shape.
auto indicesLoc = getToken().getLoc();
@@ -869,7 +888,6 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
// 2-dimensional shape where the second dimension is the rank of the type.
// Given that the parsed indices is a splat, we know that we only have one
// indice and thus one for the first dimension.
- auto indiceEltType = builder.getIntegerType(64);
ShapedType indicesType;
if (indiceParser.getShape().empty()) {
indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index bf371fd70f44..300fb7850e33 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -670,19 +670,21 @@ func @densetensorattr() -> () {
// CHECK: "fooi67"() {bar = dense<{{\[\[\[}}-5, 4, 6, 2]]]> : vector<1x1x4xi67>} : () -> ()
"fooi67"(){bar = dense<[[[-5, 4, 6, 2]]]> : vector<1x1x4xi67>} : () -> ()
-// CHECK: "foo2"() {bar = dense<[]> : tensor<0xi32>} : () -> ()
- "foo2"(){bar = dense<[]> : tensor<0xi32>} : () -> ()
-// CHECK: "foo2"() {bar = dense<{{\[\[}}]]> : tensor<1x0xi32>} : () -> ()
- "foo2"(){bar = dense<[[]]> : tensor<1x0xi32>} : () -> ()
+// CHECK: "foo2"() {bar = dense<> : tensor<0xi32>} : () -> ()
+ "foo2"(){bar = dense<> : tensor<0xi32>} : () -> ()
+// CHECK: "foo2"() {bar = dense<> : tensor<1x0xi32>} : () -> ()
+ "foo2"(){bar = dense<> : tensor<1x0xi32>} : () -> ()
+// CHECK: dense<> : tensor<0x512x512xi32>
+ "foo2"(){bar = dense<> : tensor<0x512x512xi32>} : () -> ()
// CHECK: "foo3"() {bar = dense<{{\[\[\[}}5, -6, 1, 2]], {{\[\[}}7, 8, 3, 4]]]> : tensor<2x1x4xi32>} : () -> ()
"foo3"(){bar = dense<[[[5, -6, 1, 2]], [[7, 8, 3, 4]]]> : tensor<2x1x4xi32>} : () -> ()
// CHECK: "float1"() {bar = dense<5.000000e+00> : tensor<1x1x1xf32>} : () -> ()
"float1"(){bar = dense<[[[5.0]]]> : tensor<1x1x1xf32>} : () -> ()
-// CHECK: "float2"() {bar = dense<[]> : tensor<0xf32>} : () -> ()
- "float2"(){bar = dense<[]> : tensor<0xf32>} : () -> ()
-// CHECK: "float2"() {bar = dense<{{\[\[}}]]> : tensor<1x0xf32>} : () -> ()
- "float2"(){bar = dense<[[]]> : tensor<1x0xf32>} : () -> ()
+// CHECK: "float2"() {bar = dense<> : tensor<0xf32>} : () -> ()
+ "float2"(){bar = dense<> : tensor<0xf32>} : () -> ()
+// CHECK: "float2"() {bar = dense<> : tensor<1x0xf32>} : () -> ()
+ "float2"(){bar = dense<> : tensor<1x0xf32>} : () -> ()
// CHECK: "bfloat16"() {bar = dense<{{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]> : tensor<2x1x4xbf16>} : () -> ()
"bfloat16"(){bar = dense<[[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]> : tensor<2x1x4xbf16>} : () -> ()
@@ -752,27 +754,27 @@ func @sparsetensorattr() -> () {
"fooi8"(){bar = sparse<0, -2> : tensor<1x1x1xi8>} : () -> ()
// CHECK: "fooi16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2, -1, 5]> : tensor<2x2x2xi16>} : () -> ()
"fooi16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2, -1, 5]> : tensor<2x2x2xi16>} : () -> ()
-// CHECK: "fooi32"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<1x1xi32>} : () -> ()
- "fooi32"(){bar = sparse<[], []> : tensor<1x1xi32>} : () -> ()
+// CHECK: "fooi32"() {bar = sparse<> : tensor<1x1xi32>} : () -> ()
+ "fooi32"(){bar = sparse<> : tensor<1x1xi32>} : () -> ()
// CHECK: "fooi64"() {bar = sparse<0, -1> : tensor<1xi64>} : () -> ()
"fooi64"(){bar = sparse<[[0]], [-1]> : tensor<1xi64>} : () -> ()
-// CHECK: "foo2"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<0xi32>} : () -> ()
- "foo2"(){bar = sparse<[], []> : tensor<0xi32>} : () -> ()
-// CHECK: "foo3"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<i32>} : () -> ()
- "foo3"(){bar = sparse<[], []> : tensor<i32>} : () -> ()
+// CHECK: "foo2"() {bar = sparse<> : tensor<0xi32>} : () -> ()
+ "foo2"(){bar = sparse<> : tensor<0xi32>} : () -> ()
+// CHECK: "foo3"() {bar = sparse<> : tensor<i32>} : () -> ()
+ "foo3"(){bar = sparse<> : tensor<i32>} : () -> ()
// CHECK: "foof16"() {bar = sparse<0, -2.000000e+00> : tensor<1x1x1xf16>} : () -> ()
"foof16"(){bar = sparse<0, -2.0> : tensor<1x1x1xf16>} : () -> ()
// CHECK: "foobf16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2.000000e+00, -1.000000e+00, 5.000000e+00]> : tensor<2x2x2xbf16>} : () -> ()
"foobf16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2.0, -1.0, 5.0]> : tensor<2x2x2xbf16>} : () -> ()
-// CHECK: "foof32"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<1x0x1xf32>} : () -> ()
- "foof32"(){bar = sparse<[], []> : tensor<1x0x1xf32>} : () -> ()
+// CHECK: "foof32"() {bar = sparse<> : tensor<1x0x1xf32>} : () -> ()
+ "foof32"(){bar = sparse<> : tensor<1x0x1xf32>} : () -> ()
// CHECK: "foof64"() {bar = sparse<0, -1.000000e+00> : tensor<1xf64>} : () -> ()
"foof64"(){bar = sparse<[[0]], [-1.0]> : tensor<1xf64>} : () -> ()
-// CHECK: "foof320"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<0xf32>} : () -> ()
- "foof320"(){bar = sparse<[], []> : tensor<0xf32>} : () -> ()
-// CHECK: "foof321"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<f32>} : () -> ()
- "foof321"(){bar = sparse<[], []> : tensor<f32>} : () -> ()
+// CHECK: "foof320"() {bar = sparse<> : tensor<0xf32>} : () -> ()
+ "foof320"(){bar = sparse<> : tensor<0xf32>} : () -> ()
+// CHECK: "foof321"() {bar = sparse<> : tensor<f32>} : () -> ()
+ "foof321"(){bar = sparse<> : tensor<f32>} : () -> ()
// CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
"foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
@@ -789,8 +791,8 @@ func @sparsevectorattr() -> () {
"fooi8"(){bar = sparse<0, -2> : vector<1x1x1xi8>} : () -> ()
// CHECK: "fooi16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2, -1, 5]> : vector<2x2x2xi16>} : () -> ()
"fooi16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2, -1, 5]> : vector<2x2x2xi16>} : () -> ()
-// CHECK: "fooi32"() {bar = sparse<{{\[}}], {{\[}}]> : vector<1x1xi32>} : () -> ()
- "fooi32"(){bar = sparse<[], []> : vector<1x1xi32>} : () -> ()
+// CHECK: "fooi32"() {bar = sparse<> : vector<1x1xi32>} : () -> ()
+ "fooi32"(){bar = sparse<> : vector<1x1xi32>} : () -> ()
// CHECK: "fooi64"() {bar = sparse<0, -1> : vector<1xi64>} : () -> ()
"fooi64"(){bar = sparse<[[0]], [-1]> : vector<1xi64>} : () -> ()
More information about the Mlir-commits mailing list