[Mlir-commits] [mlir] 9ad62f6 - [mlir][sparse] remove a few rewriting failures
Aart Bik
llvmlistbot@llvm.org
Wed Nov 18 17:29:58 PST 2020
Author: Aart Bik
Date: 2020-11-18T17:29:40-08:00
New Revision: 9ad62f62b9ad9852fea17a4c81b35e281e45fbaf
URL: https://github.com/llvm/llvm-project/commit/9ad62f62b9ad9852fea17a4c81b35e281e45fbaf
DIFF: https://github.com/llvm/llvm-project/commit/9ad62f62b9ad9852fea17a4c81b35e281e45fbaf.diff
LOG: [mlir][sparse] remove a few rewriting failures
Rationale:
Make sure preconditions are already tested during verification.
Currently, the only ways a sparse rewriting rule can fail are if
(1) the linalg op does not have sparse annotations, or
(2) a yet-to-be-handled operation is encountered inside the op.
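For reference, a kernel of the following shape satisfies the preconditions
that the verifier now checks (tensor semantics, a single output tensor, one
sparse annotation of matching rank per tensor) and is therefore accepted by
the sparsification pattern. This is an illustrative sketch only; the trait
name, function name, and shapes are made up, modeled on the tests below:

  #trait_scale = {
    indexing_maps = [
      affine_map<(i) -> (i)>,  // a
      affine_map<(i) -> (i)>   // x (out)
    ],
    sparse = [
      [ "S" ],  // a: sparse input
      [ "D" ]   // x: dense output
    ],
    iterator_types = ["parallel"]
  }

  func @scale(%arga: tensor<32xf32>) -> tensor<32xf32> {
    %0 = linalg.generic #trait_scale
      ins(%arga: tensor<32xf32>) {
      ^bb(%a: f32):
        %t = addf %a, %a : f32
        linalg.yield %t : f32
    } -> tensor<32xf32>
    return %0 : tensor<32xf32>
  }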
Reviewed By: penpornk
Differential Revision: https://reviews.llvm.org/D91748
Added:
Modified:
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
mlir/test/Dialect/Linalg/sparse_invalid.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 004643ab176f..ade823fb5c41 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -419,6 +419,8 @@ LogicalResult AnnotationsVerifier<GenericOp>::verify(GenericOp op) {
// Verify consistency of sparse annotations.
if (!op.hasTensorSemantics())
return op.emitOpError("expected sparse annotations on tensors only");
+ if (op.getNumOutputs() != 1)
+ return op.emitOpError("expected single output tensor");
unsigned numTensors = op.getNumInputsAndOutputs();
if (sparseAttr.size() != numTensors)
return op.emitOpError("expected one sparse annotation for each tensor");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index 44d1e40e3453..f449ed3c3343 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -830,22 +830,16 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
LogicalResult matchAndRewrite(linalg::GenericOp op,
PatternRewriter &rewriter) const override {
- unsigned numTensors = op.getNumInputsAndOutputs();
- unsigned numLoops = op.iterator_types().getValue().size();
- Merger merger(numTensors, numLoops);
-
// Detects sparse annotations and translate the per-dimension sparsity
// information for all tensors to loop indices in the kernel.
if (!op.hasSparseSemantics())
return failure();
+ assert(op.getNumOutputs() == 1);
+ unsigned numTensors = op.getNumInputsAndOutputs();
+ unsigned numLoops = op.iterator_types().getValue().size();
+ Merger merger(numTensors, numLoops);
findSparseAnnotations(op, merger.sparse());
- // Accept only single, dense result.
- if (op.getNumOutputs() != 1 ||
- std::any_of(merger.sparse().back().begin(),
- merger.sparse().back().end(), [](bool b) { return b; }))
- return failure();
-
// Computes a topologically sorted iteration graph to ensure
// tensors are visited in natural index order. Fails on cycles.
// This assumes that higher-level passes have already put the
@@ -858,10 +852,7 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
// Finds the terminating yield statement and builds the tensor
// expression for the Linalg operation in SSA form.
- auto &region = op.region();
- if (!llvm::hasSingleElement(region))
- return failure(); // single block only
- Operation *yield = region.front().getTerminator();
+ Operation *yield = op.region().front().getTerminator();
Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
if (!exp.hasValue())
return failure(); // build failure
diff --git a/mlir/test/Dialect/Linalg/sparse_invalid.mlir b/mlir/test/Dialect/Linalg/sparse_invalid.mlir
index 985667ce433b..a75ec361a7a1 100644
--- a/mlir/test/Dialect/Linalg/sparse_invalid.mlir
+++ b/mlir/test/Dialect/Linalg/sparse_invalid.mlir
@@ -13,7 +13,7 @@
}
func @invalid_memref(%arga: memref<32xf32>, %argb: f32) -> tensor<32xf32> {
- // expected-error @+1 {{'linalg.generic' op expected sparse annotations on tensors only}}
+ // expected-error @+1 {{'linalg.generic' op expected sparse annotations on tensors only}}
%0 = linalg.generic #trait_memref
ins(%arga: memref<32xf32>) {
^bb(%a: f32):
@@ -25,6 +25,79 @@ func @invalid_memref(%arga: memref<32xf32>, %argb: f32) -> tensor<32xf32> {
// -----
+#trait_two_out = {
+ indexing_maps = [
+ affine_map<(i) -> (i)>, // a
+ affine_map<(i) -> (i)>, // x (out)
+ affine_map<(i) -> (i)> // y (out)
+ ],
+ sparse = [
+ [ "S" ], // a
+ [ "D" ], // x
+ [ "D" ] // y
+ ],
+ iterator_types = ["parallel"]
+}
+
+func @invalid_two_out(%arga: tensor<32xf32>) -> tensor<32xf32> {
+ // expected-error @+1 {{'linalg.generic' op expected single output tensor}}
+ %0, %1 = linalg.generic #trait_two_out
+ ins(%arga: tensor<32xf32>) {
+ ^bb(%a: f32):
+ %0 = addf %a, %a : f32
+ linalg.yield %a, %0 : f32, f32
+ } -> tensor<32xf32>, tensor<32xf32>
+ return %1 : tensor<32xf32>
+}
+
+// -----
+
+#trait_two_blocks = {
+ indexing_maps = [
+ affine_map<(i) -> (i)>, // a
+ affine_map<(i) -> (i)> // x (out)
+ ],
+ sparse = [
+ [ "S" ], // a
+ [ "D" ] // x
+ ],
+ iterator_types = ["parallel"]
+}
+
+func @invalid_two_blocks(%arga: tensor<32xf32>) -> tensor<32xf32> {
+ // expected-error @+1 {{'linalg.generic' op expects region #0 to have 0 or 1 blocks}}
+ %0 = linalg.generic #trait_two_blocks
+ ins(%arga: tensor<32xf32>) {
+ ^bb1(%a: f32):
+ %0 = addf %a, %a : f32
+ ^bb2:
+ linalg.yield %0 : f32
+ } -> tensor<32xf32>
+ return %0 : tensor<32xf32>
+}
+
+// -----
+
+#trait_no_block = {
+ indexing_maps = [
+ affine_map<(i) -> (i)> // a
+ ],
+ sparse = [
+ [ "S" ] // a
+ ],
+ iterator_types = ["parallel"]
+}
+
+func @invalid_no_block(%arga: tensor<32xf32>) {
+ // expected-error @+1 {{'linalg.generic' op expected region with 1 block}}
+ linalg.generic #trait_no_block
+ ins(%arga: tensor<32xf32>) {
+ }
+ return
+}
+
+// -----
+
#trait_too_many = {
indexing_maps = [
affine_map<(i) -> (i)>, // a
@@ -39,7 +112,7 @@ func @invalid_memref(%arga: memref<32xf32>, %argb: f32) -> tensor<32xf32> {
}
func @invalid_too_many(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
- // expected-error @+1 {{'linalg.generic' op expected one sparse annotation for each tensor}}
+ // expected-error @+1 {{'linalg.generic' op expected one sparse annotation for each tensor}}
%0 = linalg.generic #trait_too_many
ins(%arga: tensor<32xf32>) {
^bb(%a: f32):
@@ -61,7 +134,7 @@ func @invalid_too_many(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
}
func @invalid_no_array(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
- // expected-error @+1 {{'linalg.generic' op expected sparse annotation array for tensor 0}}
+ // expected-error @+1 {{'linalg.generic' op expected sparse annotation array for tensor 0}}
%0 = linalg.generic #trait_no_array
ins(%arga: tensor<32xf32>) {
^bb(%a: f32):
@@ -86,7 +159,7 @@ func @invalid_no_array(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
}
func @invalid_wrong_rank(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
- // expected-error @+1 {{'linalg.generic' op expected sparse annotation with rank 1 for tensor 1}}
+ // expected-error @+1 {{'linalg.generic' op expected sparse annotation with rank 1 for tensor 1}}
%0 = linalg.generic #trait_wrong_rank
ins(%arga: tensor<32xf32>) {
^bb(%a: f32):
@@ -111,7 +184,7 @@ func @invalid_wrong_rank(%arga: tensor<32xf32>, %argb: f32) -> tensor<32xf32> {
}
func @invalid_no_string(%arga: tensor<32x16xf32>, %argb: f32) -> tensor<32x16xf32> {
- // expected-error @+1 {{'linalg.generic' op expected sparse annotation at position 1 for tensor 0}}
+ // expected-error @+1 {{'linalg.generic' op expected sparse annotation at position 1 for tensor 0}}
%0 = linalg.generic #trait_no_string
ins(%arga: tensor<32x16xf32>) {
^bb(%a: f32):