[Mlir-commits] [mlir] bf9ef3e - [mlir][sparse] skip sparsification for unannotated (or unhandled) cases
Aart Bik
llvmlistbot at llvm.org
Wed May 19 13:49:42 PDT 2021
Author: Aart Bik
Date: 2021-05-19T13:49:28-07:00
New Revision: bf9ef3efaa99c02e7bfc4c57207301b8de39a278
URL: https://github.com/llvm/llvm-project/commit/bf9ef3efaa99c02e7bfc4c57207301b8de39a278
DIFF: https://github.com/llvm/llvm-project/commit/bf9ef3efaa99c02e7bfc4c57207301b8de39a278.diff
LOG: [mlir][sparse] skip sparsification for unannotated (or unhandled) cases
Skip the sparsification pass for Linalg ops without annotated tensors
(or for cases that are not properly handled yet).
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D102787
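[Editorial illustration, not part of the commit; the trait and function names
below are hypothetical.] With this change, a kernel whose tensors carry no
sparse_tensor.encoding at all is simply left untouched by -sparsification:
the rewrite pattern now signals failure instead of lowering the op. A minimal
sketch of such a skipped kernel:

#trait_scale = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a (in)
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"]
}

// No operand is annotated, so the sparsifier now skips this op
// and leaves the linalg.generic in its original form.
func @dense_scale(%arga: tensor<32xf32>,
                  %argx: tensor<32xf32>) -> tensor<32xf32> {
  %0 = linalg.generic #trait_scale
    ins(%arga: tensor<32xf32>)
    outs(%argx: tensor<32xf32>) {
    ^bb(%a: f32, %x: f32):
      %1 = mulf %a, %a : f32
      linalg.yield %1 : f32
  } -> tensor<32xf32>
  return %0 : tensor<32xf32>
}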
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
mlir/test/Dialect/SparseTensor/sparse_3d.mlir
mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 7348a7391f3ae..6fb90dcc645f8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -347,17 +347,27 @@ static Dim toDim(SparseTensorEncodingAttr &enc, unsigned d) {
 
 /// Helper method to inspect sparse encodings in the tensor types.
 /// Fills the per-dimension sparsity information for all tensors.
-static void findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
+static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
+  bool annotated = false;
   unsigned numTensors = op.getNumShapedOperands();
+  unsigned lhs = numTensors - 1;
   for (unsigned t = 0; t < numTensors; t++) {
     auto map = op.getIndexingMap(t);
     unsigned rank = op.getShapedType(t).getRank();
     auto enc = getSparseTensorEncoding(op.getShapedType(t));
+    if (enc) {
+      annotated = true;
+      if (enc.getDimOrdering() && !enc.getDimOrdering().isIdentity())
+        return false; // TODO: handle permutations
+      if (t == lhs)
+        return false; // TODO: handle sparse outputs
+    }
     for (unsigned d = 0; d < rank; d++) {
       unsigned idx = map.getDimPosition(d);
       merger.setDim(t, idx, toDim(enc, d));
     }
   }
+  return annotated;
 }
 
 /// A DFS helper to compute a topological sort. Note that recursion is
@@ -1356,7 +1366,8 @@ struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
     unsigned numTensors = op.getNumShapedOperands();
     unsigned numLoops = op.iterator_types().getValue().size();
     Merger merger(numTensors, numLoops);
-    findSparseAnnotations(merger, op);
+    if (!findSparseAnnotations(merger, op))
+      return failure();
 
     // Computes a topologically sorted iteration graph to ensure
     // tensors are visited in natural index order. Fails on cycles.
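[Editorial illustration, not part of the commit; encoding, trait, and function
names are hypothetical.] The two TODO bail-outs above also make the pass
decline annotated cases it cannot handle yet. For example, an output tensor
with a sparse encoding (the last shaped operand, t == lhs) makes
findSparseAnnotations() return false, as does any encoding with a
non-identity dimOrdering:

#CV = #sparse_tensor.encoding<{ dimLevelType = [ "compressed" ] }>

#trait_neg = {
  indexing_maps = [
    affine_map<(i) -> (i)>,  // a (in)
    affine_map<(i) -> (i)>   // x (out)
  ],
  iterator_types = ["parallel"]
}

// The annotated output hits the "TODO: handle sparse outputs"
// bail-out, so matchAndRewrite() returns failure() and the op
// is left alone rather than sparsified incorrectly.
func @sparse_out(%arga: tensor<32xf32, #CV>,
                 %argx: tensor<32xf32, #CV>) -> tensor<32xf32, #CV> {
  %0 = linalg.generic #trait_neg
    ins(%arga: tensor<32xf32, #CV>)
    outs(%argx: tensor<32xf32, #CV>) {
    ^bb(%a: f32, %x: f32):
      %1 = negf %a : f32
      linalg.yield %1 : f32
  } -> tensor<32xf32, #CV>
  return %0 : tensor<32xf32, #CV>
}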
diff --git a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
index 1c636fee6dc8b..afdcd32c3ee70 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_3d.mlir
@@ -1,6 +1,8 @@
 // NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
 // RUN: mlir-opt %s -sparsification | FileCheck %s
 
+#Td = #sparse_tensor.encoding<{ dimLevelType = [ "dense" ] }>
+
 #Tddd = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "dense" ] }>
 #Tdds = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "dense", "compressed" ] }>
 #Tdsd = #sparse_tensor.encoding<{ dimLevelType = [ "dense", "compressed", "dense" ] }>
@@ -1249,7 +1251,7 @@ func @sum_reduction(%arga: tensor<10x20x30xf32, #Tsss>, %argx: tensor<f32>) -> t
 
 // CHECK-LABEL: func @sum_reduction_inv(
 // CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xf32>,
-// CHECK-SAME: %[[VAL_1:.*]]: tensor<?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME: %[[VAL_2:.*]]: tensor<f32>) -> tensor<f32> {
 // CHECK: %[[VAL_3:.*]] = constant 2 : index
 // CHECK: %[[VAL_4:.*]] = constant 0 : index
@@ -1257,8 +1259,8 @@ func @sum_reduction(%arga: tensor<10x20x30xf32, #Tsss>, %argx: tensor<f32>) -> t
 // CHECK: %[[VAL_6:.*]] = memref.dim %[[VAL_0]], %[[VAL_5]] : tensor<?x?x?xf32>
 // CHECK: %[[VAL_7:.*]] = memref.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?x?xf32>
 // CHECK: %[[VAL_8:.*]] = memref.buffer_cast %[[VAL_0]] : memref<?x?x?xf32>
-// CHECK: %[[VAL_9:.*]] = memref.dim %[[VAL_1]], %[[VAL_4]] : tensor<?xf32>
-// CHECK: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<?xf32>
+// CHECK: %[[VAL_9:.*]] = memref.dim %[[VAL_1]], %[[VAL_4]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>
+// CHECK: %[[VAL_10:.*]] = sparse_tensor.values %[[VAL_1]] : tensor<?xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<f32>
 // CHECK: %[[VAL_12:.*]] = memref.alloc() : memref<f32>
 // CHECK: linalg.copy(%[[VAL_11]], %[[VAL_12]]) : memref<f32>, memref<f32>
@@ -1279,10 +1281,10 @@ func @sum_reduction(%arga: tensor<10x20x30xf32, #Tsss>, %argx: tensor<f32>) -> t
 // CHECK: return %[[VAL_24]] : tensor<f32>
 // CHECK: }
 func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
-                        %argb: tensor<?xf32>,
+                        %argb: tensor<?xf32, #Td>,
                         %argx: tensor<f32>) -> tensor<f32> {
   %0 = linalg.generic #trait_sum_reduction_inv
-    ins(%arga, %argb: tensor<?x?x?xf32>, tensor<?xf32>)
+    ins(%arga, %argb: tensor<?x?x?xf32>, tensor<?xf32, #Td>)
     outs(%argx: tensor<f32>) {
     ^bb(%a: f32, %b: f32, %x: f32):
       %0 = mulf %a, %b : f32
@@ -1304,7 +1306,7 @@ func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
 }
 
 // CHECK-LABEL: func @invariants(
-// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xf32>,
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>>,
 // CHECK-SAME: %[[VAL_1:.*]]: tensor<20xf32>,
 // CHECK-SAME: %[[VAL_2:.*]]: tensor<30xf32>,
 // CHECK-SAME: %[[VAL_3:.*]]: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
@@ -1313,14 +1315,14 @@ func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
 // CHECK: %[[VAL_6:.*]] = constant 30 : index
 // CHECK: %[[VAL_7:.*]] = constant 0 : index
 // CHECK: %[[VAL_8:.*]] = constant 1 : index
-// CHECK: %[[VAL_9:.*]] = memref.buffer_cast %[[VAL_0]] : memref<10xf32>
+// CHECK: %[[VAL_9:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<10xf32, #sparse_tensor.encoding<{ dimLevelType = [ "dense" ], pointerBitWidth = 0, indexBitWidth = 0 }>> to memref<?xf32>
 // CHECK: %[[VAL_10:.*]] = memref.buffer_cast %[[VAL_1]] : memref<20xf32>
 // CHECK: %[[VAL_11:.*]] = memref.buffer_cast %[[VAL_2]] : memref<30xf32>
 // CHECK: %[[VAL_12:.*]] = memref.buffer_cast %[[VAL_3]] : memref<10x20x30xf32>
 // CHECK: %[[VAL_13:.*]] = memref.alloc() : memref<10x20x30xf32>
 // CHECK: linalg.copy(%[[VAL_12]], %[[VAL_13]]) : memref<10x20x30xf32>, memref<10x20x30xf32>
 // CHECK: scf.for %[[VAL_14:.*]] = %[[VAL_7]] to %[[VAL_4]] step %[[VAL_8]] {
-// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<10xf32>
+// CHECK: %[[VAL_15:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_14]]] : memref<?xf32>
 // CHECK: scf.for %[[VAL_16:.*]] = %[[VAL_7]] to %[[VAL_5]] step %[[VAL_8]] {
 // CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_16]]] : memref<20xf32>
 // CHECK: scf.for %[[VAL_18:.*]] = %[[VAL_7]] to %[[VAL_6]] step %[[VAL_8]] {
@@ -1334,12 +1336,12 @@ func @sum_reduction_inv(%arga: tensor<?x?x?xf32>,
 // CHECK: %[[VAL_22:.*]] = memref.tensor_load %[[VAL_13]] : memref<10x20x30xf32>
 // CHECK: return %[[VAL_22]] : tensor<10x20x30xf32>
 // CHECK: }
-func @invariants(%arga: tensor<10xf32>,
+func @invariants(%arga: tensor<10xf32, #Td>,
                  %argb: tensor<20xf32>,
                  %argc: tensor<30xf32>,
                  %argx: tensor<10x20x30xf32>) -> tensor<10x20x30xf32> {
   %0 = linalg.generic #trait_invariants
-    ins(%arga, %argb, %argc : tensor<10xf32>, tensor<20xf32>, tensor<30xf32>)
+    ins(%arga, %argb, %argc : tensor<10xf32, #Td>, tensor<20xf32>, tensor<30xf32>)
     outs(%argx: tensor<10x20x30xf32>) {
     ^bb(%a: f32, %b: f32, %c: f32, %x: f32):
       %0 = mulf %a, %b : f32
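[Editorial note, not part of the commit: the sparse_3d.mlir updates above give
one operand of each kernel the all-dense #Td encoding so the ops remain
annotated and the sparsifier still processes them. As the CHECK changes show,
an annotated operand is then read through sparse_tensor.values rather than
memref.buffer_cast, roughly:

// before: %9 = memref.buffer_cast %arg : memref<10xf32>
// after:  %9 = sparse_tensor.values %arg : tensor<10xf32, #Td> to memref<?xf32>

where #Td abbreviates the full encoding printed in the CHECK lines.]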
diff --git a/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir b/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
index fc575141e311a..a48d035dcafda 100644
--- a/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_parallel.mlir
@@ -9,6 +9,10 @@
// RUN: mlir-opt %s -sparsification="parallelization-strategy=4" | \
// RUN: FileCheck %s --check-prefix=CHECK-PAR4
+#DenseMatrix = #sparse_tensor.encoding<{
+ dimLevelType = [ "dense", "dense" ]
+}>
+
#SparseMatrix = #sparse_tensor.encoding<{
dimLevelType = [ "compressed", "compressed" ]
}>
@@ -52,9 +56,11 @@
 // CHECK-PAR4: scf.parallel
 // CHECK-PAR4: return
 //
-func @scale_dd(%scale: f32, %arga: tensor<?x?xf32>, %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @scale_dd(%scale: f32,
+               %arga: tensor<?x?xf32, #DenseMatrix>,
+               %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic #trait_dd
-    ins(%arga: tensor<?x?xf32>)
+    ins(%arga: tensor<?x?xf32, #DenseMatrix>)
     outs(%argx: tensor<?x?xf32>) {
     ^bb(%a: f32, %x: f32):
       %0 = mulf %a, %scale : f32
@@ -98,7 +104,9 @@ func @scale_dd(%scale: f32, %arga: tensor<?x?xf32>, %argx: tensor<?x?xf32>) -> t
 // CHECK-PAR4: scf.parallel
 // CHECK-PAR4: return
 //
-func @scale_ss(%scale: f32, %arga: tensor<?x?xf32, #SparseMatrix>, %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @scale_ss(%scale: f32,
+               %arga: tensor<?x?xf32, #SparseMatrix>,
+               %argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
   %0 = linalg.generic #trait_ss
     ins(%arga: tensor<?x?xf32, #SparseMatrix>)
     outs(%argx: tensor<?x?xf32>) {
@@ -145,9 +153,11 @@ func @scale_ss(%scale: f32, %arga: tensor<?x?xf32, #SparseMatrix>, %argx: tensor
 // CHECK-PAR4: scf.for
 // CHECK-PAR4: return
 //
-func @matvec(%argA: tensor<16x32xf32, #CSR>, %argb: tensor<32xf32>, %argx: tensor<16xf32>) -> tensor<16xf32> {
+func @matvec(%arga: tensor<16x32xf32, #CSR>,
+             %argb: tensor<32xf32>,
+             %argx: tensor<16xf32>) -> tensor<16xf32> {
   %0 = linalg.generic #trait_matvec
-    ins(%argA, %argb : tensor<16x32xf32, #CSR>, tensor<32xf32>)
+    ins(%arga, %argb : tensor<16x32xf32, #CSR>, tensor<32xf32>)
     outs(%argx: tensor<16xf32>) {
     ^bb(%A: f32, %b: f32, %x: f32):
       %0 = mulf %A, %b : f32