[Mlir-commits] [mlir] ce3d0e8 - [mlir][sparse] enable SDDMM-flavored fusion
Aart Bik
llvmlistbot at llvm.org
Tue Aug 2 12:40:14 PDT 2022
Author: Aart Bik
Date: 2022-08-02T12:40:04-07:00
New Revision: ce3d0e87ac23ce4c2be8a3b99c3020f930d7ba16
URL: https://github.com/llvm/llvm-project/commit/ce3d0e87ac23ce4c2be8a3b99c3020f930d7ba16
DIFF: https://github.com/llvm/llvm-project/commit/ce3d0e87ac23ce4c2be8a3b99c3020f930d7ba16.diff
LOG: [mlir][sparse] enable SDDMM-flavored fusion
This rewriting was no longer functional after the recent migration to one-shot
bufferization. This revision makes it work again and adds a CHECK test to
ensure that fusion actually happens. Note that the functionality itself is
also covered by several integration tests.
Reviewed By: Peiming
Differential Revision: https://reviews.llvm.org/D130996
Added:
mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
Modified:
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index a01d7c86ada18..f0aef0f0f6386 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -41,12 +41,17 @@ static bool isSparseTensor(OpOperand *op) {
return false;
}
-// Helper method to find zero or empty initialization.
-static bool isEmptyInit(OpOperand *op) {
+// Helper method to find zero/uninitialized allocation.
+static bool isAlloc(OpOperand *op, bool isZero) {
Value val = op->get();
- return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) ||
- val.getDefiningOp<InitTensorOp>() ||
- val.getDefiningOp<AllocTensorOp>();
+ if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
+ Value copy = alloc.getCopy();
+ if (isZero)
+ return copy && (matchPattern(copy, m_Zero()) ||
+ matchPattern(copy, m_AnyZeroFloat()));
+ return !copy;
+ }
+ return false;
}
// Helper to detect sampling operation.
@@ -140,9 +145,9 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
!prod.getResult(0).hasOneUse())
return failure();
// Sampling consumer and sum of multiplication chain producer.
- if (!isEmptyInit(op.getOutputOperand(0)) ||
- !isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
- !isSumOfMul(prod))
+ if (!isAlloc(op.getOutputOperand(0), /*isZero=*/false) ||
+ !isAlloc(prod.getOutputOperand(0), /*isZero=*/true) ||
+ !isSampling(op) || !isSumOfMul(prod))
return failure();
// Modify operand structure of producer and consumer.
Location loc = prod.getLoc();
@@ -180,6 +185,14 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
last = rewriter.clone(*acc, mapper)->getResult(0);
rewriter.create<linalg::YieldOp>(loc, last);
+ // Force initial value on merged allocation for dense outputs.
+ if (!getSparseTensorEncoding(op.getResult(0).getType())) {
+ AllocTensorOp a1 =
+ prod.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
+ AllocTensorOp a2 =
+ op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
+ a2.getCopyMutable().assign(a1.getCopy());
+ }
// Replace consumer with fused operation. Old producer
// and consumer ops will be removed by DCE.
rewriter.replaceOp(op, fusedOp->getResults());
@@ -240,7 +253,7 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
//===---------------------------------------------------------------------===//
void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
- // TODO(springerm): enable FuseSparseMultiplyOverAdd
- patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
- ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+ patterns
+ .add<FuseSparseMultiplyOverAdd, ReshapeRewriter<tensor::ExpandShapeOp>,
+ ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
}
diff --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
new file mode 100644
index 0000000000000..4dd4fd328ce78
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -0,0 +1,162 @@
+// RUN: mlir-opt %s --tensor-copy-insertion --sparsification --cse | FileCheck %s
+
+#SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+#trait_matmul = {
+ indexing_maps = [
+ affine_map<(d0, d1, d2) -> (d1, d0)>,
+ affine_map<(d0, d1, d2) -> (d0, d2)>,
+ affine_map<(d0, d1, d2) -> (d1, d2)>
+ ],
+ iterator_types = ["reduction", "parallel", "parallel"]
+}
+
+#trait_scale = {
+ indexing_maps = [
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0, d1)>
+ ],
+ iterator_types = ["parallel", "parallel"]
+}
+
+// CHECK-LABEL: func.func @sampled_dd_unfused(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
+// CHECK: %[[VAL_7:.*]] = bufferization.alloc_tensor() copy(%[[VAL_6]]) {bufferization.escape = [false]} : tensor<8x8xf64>
+// CHECK: %[[VAL_8:.*]] = bufferization.alloc_tensor() copy(%[[VAL_6]]) {bufferization.escape = [false], memory_space = 0 : ui64} : tensor<8x8xf64>
+// CHECK: %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
+// CHECK: %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
+// CHECK: %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
+// CHECK: %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_8]] : memref<8x8xf64>
+// CHECK: %[[VAL_17:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_5]] {
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK: %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_5]] {
+// CHECK: %[[VAL_25:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<8x8xf64>
+// CHECK: %[[VAL_27:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_24]]] : memref<?xf64>
+// CHECK: %[[VAL_28:.*]] = scf.for %[[VAL_29:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_30:.*]] = %[[VAL_26]]) -> (f64) {
+// CHECK: %[[VAL_31:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]], %[[VAL_29]]] : memref<8x8xf64>
+// CHECK: %[[VAL_32:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_29]], %[[VAL_25]]] : memref<8x8xf64>
+// CHECK: %[[VAL_33:.*]] = arith.mulf %[[VAL_31]], %[[VAL_32]] : f64
+// CHECK: %[[VAL_34:.*]] = arith.mulf %[[VAL_33]], %[[VAL_27]] : f64
+// CHECK: %[[VAL_35:.*]] = arith.addf %[[VAL_30]], %[[VAL_34]] : f64
+// CHECK: scf.yield %[[VAL_35]] : f64
+// CHECK: }
+// CHECK: memref.store %[[VAL_24:.*]], %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<8x8xf64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_37:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<8x8xf64>
+// CHECK: return %[[VAL_37]] : tensor<8x8xf64>
+// CHECK: }
+func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
+ %arga: tensor<8x8xf64>,
+ %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
+ // Perform dense-dense matrix matrix multiplication.
+ %1 = arith.constant dense<0.0> : tensor<8x8xf64>
+ %2 = linalg.generic #trait_matmul
+ ins(%arga, %argb : tensor<8x8xf64>, tensor<8x8xf64>)
+ outs(%1 : tensor<8x8xf64>) {
+ ^bb0(%a: f64, %b: f64, %x: f64):
+ %p = arith.mulf %a, %b : f64
+ %q = arith.addf %x, %p : f64
+ linalg.yield %q : f64
+ } -> tensor<8x8xf64>
+ // Sample the result with element-wise multiplication by the sparse matrix.
+ %3 = linalg.generic #trait_scale
+ ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
+ outs(%1 : tensor<8x8xf64>) {
+ ^bb0(%t: f64, %s: f64, %x: f64):
+ %r = arith.mulf %t, %s : f64
+ linalg.yield %r : f64
+ } -> tensor<8x8xf64>
+ return %3 : tensor<8x8xf64>
+}
+
+// CHECK-LABEL: func.func @sparse_sampled_dd_unfused(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<8x8xf64>,
+// CHECK-SAME: %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG: %[[VAL_8:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
+// CHECK: %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) {bufferization.escape = [false]} : tensor<8x8xf64>
+// CHECK: %[[VAL_10:.*]] = bufferization.alloc_tensor() {bufferization.escape = [false]} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
+// CHECK: %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
+// CHECK: %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_15:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_16:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK: %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
+// CHECK: %[[VAL_18:.*]] = memref.alloca(%[[VAL_6]]) : memref<?xindex>
+// CHECK: %[[VAL_19:.*]] = memref.alloca() : memref<f64>
+// CHECK: %[[VAL_20:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_22:.*]] = %[[VAL_20]] to %[[VAL_21]] step %[[VAL_5]] {
+// CHECK: %[[VAL_23:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK: memref.store %[[VAL_23]], %[[VAL_18]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK: %[[VAL_24:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_22]], %[[VAL_5]] : index
+// CHECK: %[[VAL_26:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK: scf.for %[[VAL_27:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] {
+// CHECK: %[[VAL_28:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_27]]] : memref<?xindex>
+// CHECK: memref.store %[[VAL_28]], %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK: %[[VAL_29:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]]] : memref<?xf64>
+// CHECK: %[[VAL_30:.*]] = scf.for %[[VAL_31:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_32:.*]] = %[[VAL_7]]) -> (f64) {
+// CHECK: memref.store %[[VAL_31]], %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK: %[[VAL_33:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]], %[[VAL_31]]] : memref<8x8xf64>
+// CHECK: %[[VAL_34:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_31]], %[[VAL_28]]] : memref<8x8xf64>
+// CHECK: %[[VAL_35:.*]] = arith.mulf %[[VAL_33]], %[[VAL_34]] : f64
+// CHECK: %[[VAL_36:.*]] = arith.mulf %[[VAL_35]], %[[VAL_29]] : f64
+// CHECK: %[[VAL_37:.*]] = arith.addf %[[VAL_32]], %[[VAL_36]] : f64
+// CHECK: scf.yield %[[VAL_37]] : f64
+// CHECK: }
+// CHECK: memref.store %[[VAL_30:.*]], %[[VAL_19]][] : memref<f64>
+// CHECK: sparse_tensor.lex_insert %[[VAL_10]], %[[VAL_18]], %[[VAL_19]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>, memref<?xindex>, memref<f64>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[VAL_39:.*]] = sparse_tensor.load %[[VAL_10]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: return %[[VAL_39]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK: }
+func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
+ %arga: tensor<8x8xf64>,
+ %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
+ // Perform dense-dense matrix matrix multiplication.
+ %1 = arith.constant dense<0.0> : tensor<8x8xf64>
+ %2 = linalg.generic #trait_matmul
+ ins(%arga, %argb : tensor<8x8xf64>, tensor<8x8xf64>)
+ outs(%1 : tensor<8x8xf64>) {
+ ^bb0(%a: f64, %b: f64, %x: f64):
+ %p = arith.mulf %a, %b : f64
+ %q = arith.addf %x, %p : f64
+ linalg.yield %q : f64
+ } -> tensor<8x8xf64>
+ // Sample the result with element-wise multiplication by the sparse matrix.
+ %3 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
+ %4 = linalg.generic #trait_scale
+ ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
+ outs(%3 : tensor<8x8xf64, #SM>) {
+ ^bb0(%t: f64, %s: f64, %x: f64):
+ %r = arith.mulf %t, %s : f64
+ linalg.yield %r : f64
+ } -> tensor<8x8xf64, #SM>
+ return %4 : tensor<8x8xf64, #SM>
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
index fc5f8c597a0b0..3c7f741c77fcc 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
@@ -50,8 +50,8 @@ module {
// (with dense result).
//
func.func @sampled_dd(%args: tensor<8x8xf64, #SM>,
- %arga: tensor<8x8xf64>,
- %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
+ %arga: tensor<8x8xf64>,
+ %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
%1 = arith.constant dense<0.0> : tensor<8x8xf64>
%2 = linalg.generic #trait_sampled_dense_dense
ins(%args, %arga, %argb: tensor<8x8xf64, #SM>,
@@ -71,8 +71,8 @@ module {
// (with dense result).
//
func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
- %arga: tensor<8x8xf64>,
- %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
+ %arga: tensor<8x8xf64>,
+ %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
// Perform dense-dense matrix matrix multiplication.
%1 = arith.constant dense<0.0> : tensor<8x8xf64>
%2 = linalg.generic #trait_matmul
@@ -99,8 +99,8 @@ module {
// (with sparse result).
//
func.func @sparse_sampled_dd(%args: tensor<8x8xf64, #SM>,
- %arga: tensor<8x8xf64>,
- %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
+ %arga: tensor<8x8xf64>,
+ %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
%1 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
%2 = linalg.generic #trait_sampled_dense_dense
ins(%args, %arga, %argb: tensor<8x8xf64, #SM>,
More information about the Mlir-commits
mailing list