[Mlir-commits] [mlir] e6cbb91 - [mlir][sparse] skip zeros during dense2sparse
Aart Bik
llvmlistbot at llvm.org
Wed Nov 9 20:54:36 PST 2022
Author: Aart Bik
Date: 2022-11-09T20:54:27-08:00
New Revision: e6cbb9148366ea4ac35e34bfe351394560934187
URL: https://github.com/llvm/llvm-project/commit/e6cbb9148366ea4ac35e34bfe351394560934187
DIFF: https://github.com/llvm/llvm-project/commit/e6cbb9148366ea4ac35e34bfe351394560934187.diff
LOG: [mlir][sparse] skip zeros during dense2sparse
This enables the full matmul integration test with both runtime_lib=true and runtime_lib=false!
Background:
https://github.com/llvm/llvm-project/issues/51657
Reviewed By: bixia
Differential Revision: https://reviews.llvm.org/D137750
Added:
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index 701be4c163ae6..57230ebfac535 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -533,6 +533,13 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
SmallVector<Value, 4> dynSizes;
getDynamicSizes(dstTp, sizes, dynSizes);
+ bool fromSparseConst = false;
+ if (auto constOp = op.getSource().getDefiningOp<arith::ConstantOp>()) {
+ if (constOp.getValue().dyn_cast<SparseElementsAttr>()) {
+ fromSparseConst = true;
+ }
+ }
+
RankedTensorType cooTp = getUnorderedCOOFromType(dstTp);
auto cooBuffer =
rewriter.create<AllocTensorOp>(loc, cooTp, dynSizes).getResult();
@@ -540,8 +547,22 @@ struct ConvertRewriter : public OpRewritePattern<ConvertOp> {
loc, src, cooBuffer,
[&](OpBuilder &builder, Location loc, ValueRange indices, Value v,
ValueRange reduc) {
- builder.create<sparse_tensor::YieldOp>(
- loc, builder.create<InsertOp>(loc, v, reduc.front(), indices));
+ Value input = reduc.front();
+ if (fromSparseConst) {
+ input = builder.create<InsertOp>(loc, v, input, indices);
+ } else {
+ Value cond = genIsNonzero(builder, loc, v);
+ auto ifOp = builder.create<scf::IfOp>(
+ loc, TypeRange(input.getType()), cond, /*else*/ true);
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+ Value insert = builder.create<InsertOp>(loc, v, input, indices);
+ builder.create<scf::YieldOp>(loc, insert);
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ builder.create<scf::YieldOp>(loc, input);
+ builder.setInsertionPointAfter(ifOp);
+ input = ifOp.getResult(0);
+ }
+ builder.create<sparse_tensor::YieldOp>(loc, input);
});
rewriter.setInsertionPointAfter(op);
src = rewriter.create<LoadOp>(loc, foreachOp.getResult(0), true);
diff --git a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
index 96d78874a19cc..7c80314f83219 100644
--- a/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
+++ b/mlir/test/Dialect/SparseTensor/convert_dense2sparse.mlir
@@ -109,8 +109,14 @@ func.func @sparse_convert_complex(%arg0: tensor<100xcomplex<f64>>) -> tensor<100
// CHECK-RWT: %[[VAL_2:.*]] = bufferization.alloc_tensor()
// CHECK-RWT: %[[VAL_3:.*]] = sparse_tensor.foreach in %[[T0]] init(%[[VAL_2]])
// CHECK-RWT: ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f64, %[[VAL_7:.*]]: tensor
-// CHECK-RWT: %[[VAL_8:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]]
-// CHECK-RWT: sparse_tensor.yield %[[VAL_8]]
+// CHECK-RWT: %[[CMP:.*]] = arith.cmpf une, %[[VAL_6]]
+// CHECK-RWT: %[[IFR:.*]] = scf.if %[[CMP]]
+// CHECK-RWT: %[[Y1:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]
+// CHECK-RWT: scf.yield %[[Y1]]
+// CHECK-RWT: } else {
+// CHECK-RWT: scf.yield %[[VAL_7]]
+// CHECK-RWT: }
+// CHECK-RWT: sparse_tensor.yield %[[IFR]]
// CHECK-RWT: }
// CHECK-RWT: %[[COO:.*]] = sparse_tensor.load %[[VAL_3]] hasInserts
// CHECK-RWT: %[[I0:.*]] = sparse_tensor.indices %[[COO]] {dimension = 0 : index}
@@ -166,7 +172,7 @@ func.func @sparse_convert_2d(%arg0: tensor<2x4xf64>) -> tensor<2x4xf64, #CSR> {
// CHECK-RWT: %[[VAL_0:.*]] = arith.constant 1 : index
// CHECK-RWT: %[[VAL_1:.*]] = arith.constant sparse<{{\[\[}}0, 0], [1, 6]], [1.000000e+00, 5.000000e+00]> : tensor<8x7xf32>
// CHECK-RWT: %[[T0:.*]] = bufferization.alloc_tensor()
-// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[VAL_1]] init(%[[T0]])
+// CHECK-RWT: %[[T1:.*]] = sparse_tensor.foreach in %[[VAL_1]] init(%[[T0]])
// CHECK-RWT: ^bb0(%[[VAL_4:.*]]: index, %[[VAL_5:.*]]: index, %[[VAL_6:.*]]: f32, %[[VAL_7:.*]]: tensor
// CHECK-RWT: %[[T2:.*]] = sparse_tensor.insert %[[VAL_6]] into %[[VAL_7]]{{\[}}%[[VAL_4]], %[[VAL_5]]]
// CHECK-RWT: sparse_tensor.yield %[[T2]]
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
index 88238e956a7b3..ce930eb7feec5 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_matmul.mlir
@@ -116,6 +116,35 @@ module {
%b3 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #CSR>
%b4 = sparse_tensor.convert %sb : tensor<8x4xf64> to tensor<8x4xf64, #DCSR>
+ //
+ // Sanity check on stored entries before going into the computations.
+ //
+ // CHECK: 32
+ // CHECK-NEXT: 32
+ // CHECK-NEXT: 4
+ // CHECK-NEXT: 4
+ // CHECK-NEXT: 32
+ // CHECK-NEXT: 32
+ // CHECK-NEXT: 8
+ // CHECK-NEXT: 8
+ //
+ %noea1 = sparse_tensor.number_of_entries %a1 : tensor<4x8xf64, #CSR>
+ %noea2 = sparse_tensor.number_of_entries %a2 : tensor<4x8xf64, #DCSR>
+ %noea3 = sparse_tensor.number_of_entries %a3 : tensor<4x8xf64, #CSR>
+ %noea4 = sparse_tensor.number_of_entries %a4 : tensor<4x8xf64, #DCSR>
+ %noeb1 = sparse_tensor.number_of_entries %b1 : tensor<8x4xf64, #CSR>
+ %noeb2 = sparse_tensor.number_of_entries %b2 : tensor<8x4xf64, #DCSR>
+ %noeb3 = sparse_tensor.number_of_entries %b3 : tensor<8x4xf64, #CSR>
+ %noeb4 = sparse_tensor.number_of_entries %b4 : tensor<8x4xf64, #DCSR>
+ vector.print %noea1 : index
+ vector.print %noea2 : index
+ vector.print %noea3 : index
+ vector.print %noea4 : index
+ vector.print %noeb1 : index
+ vector.print %noeb2 : index
+ vector.print %noeb3 : index
+ vector.print %noeb4 : index
+
// Call kernels with dense.
%0 = call @matmul1(%da, %db, %zero)
: (tensor<4x8xf64>, tensor<8x4xf64>, tensor<4x4xf64>) -> tensor<4x4xf64>
@@ -205,20 +234,20 @@ module {
vector.print %v5 : vector<4x4xf64>
//
- // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+ // CHECK-NEXT: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
//
%v6 = vector.transfer_read %6[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
vector.print %v6 : vector<4x4xf64>
//
- // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+ // CHECK-NEXT: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
//
%c7 = sparse_tensor.convert %7 : tensor<4x4xf64, #CSR> to tensor<4x4xf64>
%v7 = vector.transfer_read %c7[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
vector.print %v7 : vector<4x4xf64>
//
- // CHECK: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
+ // CHECK-NEXT: ( ( 0, 30.5, 4.2, 0 ), ( 0, 0, 0, 0 ), ( 0, 0, 4.6, 0 ), ( 0, 0, 7, 8 ) )
//
%c8 = sparse_tensor.convert %8 : tensor<4x4xf64, #DCSR> to tensor<4x4xf64>
%v8 = vector.transfer_read %c8[%c0, %c0], %d1 : tensor<4x4xf64>, vector<4x4xf64>
@@ -227,17 +256,26 @@ module {
//
// Sanity check on nonzeros.
//
- // FIXME: bring this back once dense2sparse skips zeros
- //
- // C_HECK: ( 30.5, 4.2, 4.6, 7, 8 )
- // C_HECK: ( 30.5, 4.2, 4.6, 7, 8 )
+ // CHECK-NEXT: ( 30.5, 4.2, 4.6, 7, 8 )
+ // CHECK-NEXT: ( 30.5, 4.2, 4.6, 7, 8 )
//
%val7 = sparse_tensor.values %7 : tensor<4x4xf64, #CSR> to memref<?xf64>
%val8 = sparse_tensor.values %8 : tensor<4x4xf64, #DCSR> to memref<?xf64>
- %nz7 = vector.transfer_read %val7[%c0], %d1 : memref<?xf64>, vector<8xf64>
- %nz8 = vector.transfer_read %val8[%c0], %d1 : memref<?xf64>, vector<8xf64>
- vector.print %nz7 : vector<8xf64>
- vector.print %nz8 : vector<8xf64>
+ %nz7 = vector.transfer_read %val7[%c0], %d1 : memref<?xf64>, vector<5xf64>
+ %nz8 = vector.transfer_read %val8[%c0], %d1 : memref<?xf64>, vector<5xf64>
+ vector.print %nz7 : vector<5xf64>
+ vector.print %nz8 : vector<5xf64>
+
+ //
+ // Sanity check on stored entries after the computations.
+ //
+ // CHECK-NEXT: 5
+ // CHECK-NEXT: 5
+ //
+ %noe7 = sparse_tensor.number_of_entries %7 : tensor<4x4xf64, #CSR>
+ %noe8 = sparse_tensor.number_of_entries %8 : tensor<4x4xf64, #DCSR>
+ vector.print %noe7 : index
+ vector.print %noe8 : index
// Release the resources.
bufferization.dealloc_tensor %a1 : tensor<4x8xf64, #CSR>
More information about the Mlir-commits
mailing list