[Mlir-commits] [mlir] 24aa4ef - [mlir] Print 0 element DenseElementsAttr as dense<> to fix parser bugs with expected shape.

River Riddle llvmlistbot at llvm.org
Wed Jul 8 18:48:10 PDT 2020


Author: River Riddle
Date: 2020-07-08T18:44:23-07:00
New Revision: 24aa4efffd831a1125b4eb835e1911fa38f501d7

URL: https://github.com/llvm/llvm-project/commit/24aa4efffd831a1125b4eb835e1911fa38f501d7
DIFF: https://github.com/llvm/llvm-project/commit/24aa4efffd831a1125b4eb835e1911fa38f501d7.diff

LOG: [mlir] Print 0 element DenseElementsAttr as dense<> to fix parser bugs with expected shape.

Depending on where the 0 dimension is within the shape, the parser currently rejects .mlir files generated by the printer.
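
For example (an illustrative reconstruction, not text from the commit): for a
zero-element type such as tensor<0x512x512xi32>, the old printer emitted one
pair of brackets per dimension,

  "foo2"(){bar = dense<[[[]]]> : tensor<0x512x512xi32>} : () -> ()

but the parser infers the shape 1x1x0 from that nesting and rejects it because
it does not match the declared type. Printing the zero-element case as dense<>
(and sparse<> for sparse attributes) sidesteps shape inference entirely; the
shape comes from the trailing type.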

Differential Revision: https://reviews.llvm.org/D83445

Added: 
    

Modified: 
    mlir/lib/IR/AsmPrinter.cpp
    mlir/lib/Parser/AttributeParser.cpp
    mlir/test/IR/parser.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 881f77f6004a..09135021a732 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1432,10 +1432,12 @@ void ModulePrinter::printAttribute(Attribute attr,
       break;
     }
     os << "sparse<";
-    printDenseIntOrFPElementsAttr(elementsAttr.getIndices(),
-                                  /*allowHex=*/false);
-    os << ", ";
-    printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
+    DenseIntElementsAttr indices = elementsAttr.getIndices();
+    if (indices.getNumElements() != 0) {
+      printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
+      os << ", ";
+      printDenseElementsAttr(elementsAttr.getValues(), /*allowHex=*/true);
+    }
     os << '>';
     break;
   }
@@ -1476,20 +1478,15 @@ printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
 
   // Special case for degenerate tensors.
   auto numElements = type.getNumElements();
-  int64_t rank = type.getRank();
-  if (numElements == 0) {
-    for (int i = 0; i < rank; ++i)
-      os << '[';
-    for (int i = 0; i < rank; ++i)
-      os << ']';
+  if (numElements == 0)
     return;
-  }
 
   // We use a mixed-radix counter to iterate through the shape. When we bump a
   // non-least-significant digit, we emit a close bracket. When we next emit an
   // element we re-open all closed brackets.
 
   // The mixed-radix counter, with radices in 'shape'.
+  int64_t rank = type.getRank();
   SmallVector<unsigned, 4> counter(rank, 0);
   // The number of brackets that have been opened and not closed.
   unsigned openBrackets = 0;
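
With these printer changes, a DenseElementsAttr or SparseElementsAttr with zero
elements is printed with an empty body; the sparse case drops both the indices
and the values. The resulting forms (taken from the updated parser.mlir test
below) look like:

  "foo2"(){bar = dense<> : tensor<1x0xi32>} : () -> ()
  "foof32"(){bar = sparse<> : tensor<1x0x1xf32>} : () -> ()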

diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
index 65613a149ae9..e2860b115231 100644
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -758,13 +758,13 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) {
   if (parseToken(Token::less, "expected '<' after 'dense'"))
     return nullptr;
 
-  // Parse the literal data.
+  // Parse the literal data if necessary.
   TensorLiteralParser literalParser(*this);
-  if (literalParser.parse(/*allowHex=*/true))
-    return nullptr;
-
-  if (parseToken(Token::greater, "expected '>'"))
-    return nullptr;
+  if (!consumeIf(Token::greater)) {
+    if (literalParser.parse(/*allowHex=*/true) ||
+        parseToken(Token::greater, "expected '>'"))
+      return nullptr;
+  }
 
   auto typeLoc = getToken().getLoc();
   auto type = parseElementsLiteralType(attrType);
@@ -841,6 +841,25 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   if (parseToken(Token::less, "Expected '<' after 'sparse'"))
     return nullptr;
 
+  // Check for the case where all elements are sparse. The indices are
+  // represented by a 2-dimensional shape where the second dimension is the rank
+  // of the type.
+  Type indiceEltType = builder.getIntegerType(64);
+  if (consumeIf(Token::greater)) {
+    ShapedType type = parseElementsLiteralType(attrType);
+    if (!type)
+      return nullptr;
+
+    // Construct the sparse elements attr using zero element indice/value
+    // attributes.
+    ShapedType indicesType =
+        RankedTensorType::get({0, type.getRank()}, indiceEltType);
+    ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
+    return SparseElementsAttr::get(
+        type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
+        DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
+  }
+
   /// Parse the indices. We don't allow hex values here as we may need to use
   /// the inferred shape.
   auto indicesLoc = getToken().getLoc();
@@ -869,7 +888,6 @@ Attribute Parser::parseSparseElementsAttr(Type attrType) {
   // 2-dimensional shape where the second dimension is the rank of the type.
   // Given that the parsed indices is a splat, we know that we only have one
   // indice and thus one for the first dimension.
-  auto indiceEltType = builder.getIntegerType(64);
   ShapedType indicesType;
   if (indiceParser.getShape().empty()) {
     indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
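
On the parser side, dense<> is now accepted when '>' immediately follows the
keyword, with the shape and element type taken entirely from the trailing type;
sparse<> builds a 0 x rank i64 indices tensor and a 0-element values tensor, as
in the hunk above. Both of the following lines (from the updated parser.mlir
test below) now round-trip:

  "foo2"(){bar = sparse<> : tensor<0xi32>} : () -> ()
  "foo3"(){bar = sparse<> : tensor<i32>} : () -> ()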

diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
index bf371fd70f44..300fb7850e33 100644
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -670,19 +670,21 @@ func @densetensorattr() -> () {
 // CHECK: "fooi67"() {bar = dense<{{\[\[\[}}-5, 4, 6, 2]]]> : vector<1x1x4xi67>} : () -> ()
   "fooi67"(){bar = dense<[[[-5, 4, 6, 2]]]> : vector<1x1x4xi67>} : () -> ()
 
-// CHECK: "foo2"() {bar = dense<[]> : tensor<0xi32>} : () -> ()
-  "foo2"(){bar = dense<[]> : tensor<0xi32>} : () -> ()
-// CHECK: "foo2"() {bar = dense<{{\[\[}}]]> : tensor<1x0xi32>} : () -> ()
-  "foo2"(){bar = dense<[[]]> : tensor<1x0xi32>} : () -> ()
+// CHECK: "foo2"() {bar = dense<> : tensor<0xi32>} : () -> ()
+  "foo2"(){bar = dense<> : tensor<0xi32>} : () -> ()
+// CHECK: "foo2"() {bar = dense<> : tensor<1x0xi32>} : () -> ()
+  "foo2"(){bar = dense<> : tensor<1x0xi32>} : () -> ()
+// CHECK: dense<> : tensor<0x512x512xi32>
+  "foo2"(){bar = dense<> : tensor<0x512x512xi32>} : () -> ()
 // CHECK: "foo3"() {bar = dense<{{\[\[\[}}5, -6, 1, 2]], {{\[\[}}7, 8, 3, 4]]]> : tensor<2x1x4xi32>} : () -> ()
   "foo3"(){bar = dense<[[[5, -6, 1, 2]], [[7, 8, 3, 4]]]> : tensor<2x1x4xi32>} : () -> ()
 
 // CHECK: "float1"() {bar = dense<5.000000e+00> : tensor<1x1x1xf32>} : () -> ()
   "float1"(){bar = dense<[[[5.0]]]> : tensor<1x1x1xf32>} : () -> ()
-// CHECK: "float2"() {bar = dense<[]> : tensor<0xf32>} : () -> ()
-  "float2"(){bar = dense<[]> : tensor<0xf32>} : () -> ()
-// CHECK: "float2"() {bar = dense<{{\[\[}}]]> : tensor<1x0xf32>} : () -> ()
-  "float2"(){bar = dense<[[]]> : tensor<1x0xf32>} : () -> ()
+// CHECK: "float2"() {bar = dense<> : tensor<0xf32>} : () -> ()
+  "float2"(){bar = dense<> : tensor<0xf32>} : () -> ()
+// CHECK: "float2"() {bar = dense<> : tensor<1x0xf32>} : () -> ()
+  "float2"(){bar = dense<> : tensor<1x0xf32>} : () -> ()
 
 // CHECK: "bfloat16"() {bar = dense<{{\[\[\[}}-5.000000e+00, 6.000000e+00, 1.000000e+00, 2.000000e+00]], {{\[\[}}7.000000e+00, -8.000000e+00, 3.000000e+00, 4.000000e+00]]]> : tensor<2x1x4xbf16>} : () -> ()
   "bfloat16"(){bar = dense<[[[-5.0, 6.0, 1.0, 2.0]], [[7.0, -8.0, 3.0, 4.0]]]> : tensor<2x1x4xbf16>} : () -> ()
@@ -752,27 +754,27 @@ func @sparsetensorattr() -> () {
   "fooi8"(){bar = sparse<0, -2> : tensor<1x1x1xi8>} : () -> ()
 // CHECK: "fooi16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2, -1, 5]> : tensor<2x2x2xi16>} : () -> ()
   "fooi16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2, -1, 5]> : tensor<2x2x2xi16>} : () -> ()
-// CHECK: "fooi32"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<1x1xi32>} : () -> ()
-  "fooi32"(){bar = sparse<[], []> : tensor<1x1xi32>} : () -> ()
+// CHECK: "fooi32"() {bar = sparse<> : tensor<1x1xi32>} : () -> ()
+  "fooi32"(){bar = sparse<> : tensor<1x1xi32>} : () -> ()
 // CHECK: "fooi64"() {bar = sparse<0, -1> : tensor<1xi64>} : () -> ()
   "fooi64"(){bar = sparse<[[0]], [-1]> : tensor<1xi64>} : () -> ()
-// CHECK: "foo2"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<0xi32>} : () -> ()
-  "foo2"(){bar = sparse<[], []> : tensor<0xi32>} : () -> ()
-// CHECK: "foo3"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<i32>} : () -> ()
-  "foo3"(){bar = sparse<[], []> : tensor<i32>} : () -> ()
+// CHECK: "foo2"() {bar = sparse<> : tensor<0xi32>} : () -> ()
+  "foo2"(){bar = sparse<> : tensor<0xi32>} : () -> ()
+// CHECK: "foo3"() {bar = sparse<> : tensor<i32>} : () -> ()
+  "foo3"(){bar = sparse<> : tensor<i32>} : () -> ()
 
 // CHECK: "foof16"() {bar = sparse<0, -2.000000e+00> : tensor<1x1x1xf16>} : () -> ()
   "foof16"(){bar = sparse<0, -2.0> : tensor<1x1x1xf16>} : () -> ()
 // CHECK: "foobf16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2.000000e+00, -1.000000e+00, 5.000000e+00]> : tensor<2x2x2xbf16>} : () -> ()
   "foobf16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2.0, -1.0, 5.0]> : tensor<2x2x2xbf16>} : () -> ()
-// CHECK: "foof32"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<1x0x1xf32>} : () -> ()
-  "foof32"(){bar = sparse<[], []> : tensor<1x0x1xf32>} : () -> ()
+// CHECK: "foof32"() {bar = sparse<> : tensor<1x0x1xf32>} : () -> ()
+  "foof32"(){bar = sparse<> : tensor<1x0x1xf32>} : () -> ()
 // CHECK:  "foof64"() {bar = sparse<0, -1.000000e+00> : tensor<1xf64>} : () -> ()
   "foof64"(){bar = sparse<[[0]], [-1.0]> : tensor<1xf64>} : () -> ()
-// CHECK: "foof320"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<0xf32>} : () -> ()
-  "foof320"(){bar = sparse<[], []> : tensor<0xf32>} : () -> ()
-// CHECK: "foof321"() {bar = sparse<{{\[}}], {{\[}}]> : tensor<f32>} : () -> ()
-  "foof321"(){bar = sparse<[], []> : tensor<f32>} : () -> ()
+// CHECK: "foof320"() {bar = sparse<> : tensor<0xf32>} : () -> ()
+  "foof320"(){bar = sparse<> : tensor<0xf32>} : () -> ()
+// CHECK: "foof321"() {bar = sparse<> : tensor<f32>} : () -> ()
+  "foof321"(){bar = sparse<> : tensor<f32>} : () -> ()
 
 // CHECK: "foostr"() {bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
   "foostr"(){bar = sparse<0, "foo"> : tensor<1x1x1x!unknown<"">>} : () -> ()
@@ -789,8 +791,8 @@ func @sparsevectorattr() -> () {
   "fooi8"(){bar = sparse<0, -2> : vector<1x1x1xi8>} : () -> ()
 // CHECK: "fooi16"() {bar = sparse<{{\[\[}}1, 1, 0], {{\[}}0, 1, 0], {{\[}}0, 0, 1]], {{\[}}2, -1, 5]> : vector<2x2x2xi16>} : () -> ()
   "fooi16"(){bar = sparse<[[1, 1, 0], [0, 1, 0], [0, 0, 1]], [2, -1, 5]> : vector<2x2x2xi16>} : () -> ()
-// CHECK: "fooi32"() {bar = sparse<{{\[}}], {{\[}}]> : vector<1x1xi32>} : () -> ()
-  "fooi32"(){bar = sparse<[], []> : vector<1x1xi32>} : () -> ()
+// CHECK: "fooi32"() {bar = sparse<> : vector<1x1xi32>} : () -> ()
+  "fooi32"(){bar = sparse<> : vector<1x1xi32>} : () -> ()
 // CHECK: "fooi64"() {bar = sparse<0, -1> : vector<1xi64>} : () -> ()
   "fooi64"(){bar = sparse<[[0]], [-1]> : vector<1xi64>} : () -> ()
 


        

