[Mlir-commits] [mlir] 9ad62f6 - [mlir][sparse] remove a few rewriting failures

Aart Bik llvmlistbot at llvm.org
Wed Nov 18 17:29:58 PST 2020


Author: Aart Bik
Date: 2020-11-18T17:29:40-08:00
New Revision: 9ad62f62b9ad9852fea17a4c81b35e281e45fbaf

URL: https://github.com/llvm/llvm-project/commit/9ad62f62b9ad9852fea17a4c81b35e281e45fbaf
DIFF: https://github.com/llvm/llvm-project/commit/9ad62f62b9ad9852fea17a4c81b35e281e45fbaf.diff

LOG: [mlir][sparse] remove a few rewriting failures

Rationale:
Make sure preconditions are already tested during verification.
Currently, the only way a sparse rewriting rule can fail is if
(1) the linalg op does not have sparse annotations, or
(2) a not-yet-handled operation is encountered inside the op

Reviewed By: penpornk

Differential Revision: https://reviews.llvm.org/D91748

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
    mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
    mlir/test/Dialect/Linalg/sparse_invalid.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 004643ab176f..ade823fb5c41 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -419,6 +419,8 @@ LogicalResult AnnotationsVerifier<GenericOp>::verify(GenericOp op) {
   // Verify consistency of sparse annotations.
   if (!op.hasTensorSemantics())
     return op.emitOpError("expected sparse annotations on tensors only");
+  if (op.getNumOutputs() != 1)
+    return op.emitOpError("expected single output tensor");
   unsigned numTensors = op.getNumInputsAndOutputs();
   if (sparseAttr.size() != numTensors)
     return op.emitOpError("expected one sparse annotation for each tensor");

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index 44d1e40e3453..f449ed3c3343 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -830,22 +830,16 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
 
   LogicalResult matchAndRewrite(linalg::GenericOp op,
                                 PatternRewriter &rewriter) const override {
-    unsigned numTensors = op.getNumInputsAndOutputs();
-    unsigned numLoops = op.iterator_types().getValue().size();
-    Merger merger(numTensors, numLoops);
-
     // Detects sparse annotations and translate the per-dimension sparsity
     // information for all tensors to loop indices in the kernel.
     if (!op.hasSparseSemantics())
       return failure();
+    assert(op.getNumOutputs() == 1);
+    unsigned numTensors = op.getNumInputsAndOutputs();
+    unsigned numLoops = op.iterator_types().getValue().size();
+    Merger merger(numTensors, numLoops);
     findSparseAnnotations(op, merger.sparse());
 
-    // Accept only single, dense result.
-    if (op.getNumOutputs() != 1 ||
-        std::any_of(merger.sparse().back().begin(),
-                    merger.sparse().back().end(), [](bool b) { return b; }))
-      return failure();
-
     // Computes a topologically sorted iteration graph to ensure
     // tensors are visited in natural index order. Fails on cycles.
     // This assumes that higher-level passes have already put the
@@ -858,10 +852,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
 
     // Finds the terminating yield statement and builds the tensor
     // expression for the Linalg operation in SSA form.
-    auto &region = op.region();
-    if (!llvm::hasSingleElement(region))
-      return failure(); // single block only
-    Operation *yield = region.front().getTerminator();
+    Operation *yield = op.region().front().getTerminator();
     Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
     if (!exp.hasValue())
       return failure(); // build failure

diff  --git a/mlir/test/Dialect/Linalg/sparse_invalid.mlir b/mlir/test/Dialect/Linalg/sparse_invalid.mlir
index 985667ce433b..a75ec361a7a1 100644
--- a/mlir/test/Dialect/Linalg/sparse_invalid.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_invalid.mlir
@@ -13,7 +13,7 @@
 }
 
 func @invalid_memref(%arga: memref<32xf32>, %argb: f32) -> tensor<32xf32> {
- // expected-error at +1 {{'linalg.generic' op expected sparse annotations on tensors only}}
+  // expected-error at +1 {{'linalg.generic' op expected sparse annotations on tensors only}}
   %0 = linalg.generic #trait_memref
     ins(%arga: memref<32xf32>) {
       ^bb(%a: f32):
@@ -25,6 +25,79 @@ func @invalid_memref(%arga: memref<32xf32>, %argb: f32) -> tensor<32xf32> {
 
 // -----
 
+#trait_two_out = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> (i)>,  // x (out)
+    affine_map<(i) -> (i)>   // y (out)
+  ],
+  sparse = [
+    [ "S" ],  // a
+    [ "D" ],  // x
+    [ "D" ]   // y
+  ],
+  iterator_types = ["parallel"]
+}
+
+func @invalid_two_out(%arga: tensor<32xf32>) -> tensor<32xf32> {
+  // expected-error at +1 {{'linalg.generic' op expected single output tensor}}
+  %0, %1 = linalg.generic #trait_two_out
+    ins(%arga: tensor<32xf32>) {
+      ^bb(%a: f32):
+        %0 = addf %a, %a : f32
+        linalg.yield %a, %0 : f32, f32
+  } -> tensor<32xf32>, tensor<32xf32>
+  return %1 : tensor<32xf32>
+}
+
+// -----
+
+#trait_two_blocks = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>,  // a
+    affine_map<(i) -> (i)>   // x (out)
+  ],
+  sparse = [
+    [ "S" ],  // a
+    [ "D" ]   // x
+  ],
+  iterator_types = ["parallel"]
+}
+
+func @invalid_two_blocks(%arga: tensor<32xf32>) -> tensor<32xf32> {
+  // expected-error at +1 {{'linalg.generic' op expects region #0 to have 0 or 1 blocks}}
+  %0 = linalg.generic #trait_two_blocks
+    ins(%arga: tensor<32xf32>) {
+      ^bb1(%a: f32):
+        %0 = addf %a, %a : f32
+      ^bb2:
+        linalg.yield %0 : f32
+  } -> tensor<32xf32>
+  return %0 : tensor<32xf32>
+}
+
+// -----
+
+#trait_no_block = {
+  indexing_maps = [
+    affine_map<(i) -> (i)>  // a
+  ],
+  sparse = [
+    [ "S" ]  // a
+  ],
+  iterator_types = ["parallel"]
+}
+
+func @invalid_no_block(%arga: tensor<32xf32>) {
+  // expected-error at +1 {{'linalg.generic' op expected region with 1 block}}
+  linalg.generic #trait_no_block
+    ins(%arga: tensor<32xf32>) {
+    }
+  return
+}
+
+// -----
+
 #trait_too_many = {
   indexing_maps = [
     affine_map<(i) -> (i)>,  // a
@@ -39,7 +112,7 @@ func @invalid_memref(%arga: memref<32xf32>, %argb: f32) -> tensor<32xf32> {
 }
 
 func @invalid_too_many(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
- // expected-error at +1 {{'linalg.generic' op expected one sparse annotation for each tensor}}
+  // expected-error at +1 {{'linalg.generic' op expected one sparse annotation for each tensor}}
   %0 = linalg.generic #trait_too_many
     ins(%arga: tensor<32xf32>) {
       ^bb(%a: f32):
@@ -61,7 +134,7 @@ func @invalid_too_many(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
 }
 
 func @invalid_no_array(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
- // expected-error at +1 {{'linalg.generic' op expected sparse annotation array for tensor 0}}
+  // expected-error at +1 {{'linalg.generic' op expected sparse annotation array for tensor 0}}
   %0 = linalg.generic #trait_no_array
     ins(%arga: tensor<32xf32>) {
       ^bb(%a: f32):
@@ -86,7 +159,7 @@ func @invalid_no_array(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
 }
 
 func @invalid_wrong_rank(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
- // expected-error at +1 {{'linalg.generic' op expected sparse annotation with rank 1 for tensor 1}}
+  // expected-error at +1 {{'linalg.generic' op expected sparse annotation with rank 1 for tensor 1}}
   %0 = linalg.generic #trait_wrong_rank
     ins(%arga: tensor<32xf32>) {
       ^bb(%a: f32):
@@ -111,7 +184,7 @@ func @invalid_wrong_rank(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
 }
 
 func @invalid_no_string(%arga: tensor<32x16xf32>, %argb: f32) -> tensor<32x16xf32> {
- // expected-error at +1 {{'linalg.generic' op expected sparse annotation at position 1 for tensor 0}}
+  // expected-error at +1 {{'linalg.generic' op expected sparse annotation at position 1 for tensor 0}}
   %0 = linalg.generic #trait_no_string
     ins(%arga: tensor<32x16xf32>) {
       ^bb(%a: f32):


        


More information about the Mlir-commits mailing list