[Mlir-commits] [mlir] ce3d0e8 - [mlir][sparse] enable SDDMM-flavored fusion

Aart Bik llvmlistbot at llvm.org
Tue Aug 2 12:40:14 PDT 2022


Author: Aart Bik
Date: 2022-08-02T12:40:04-07:00
New Revision: ce3d0e87ac23ce4c2be8a3b99c3020f930d7ba16

URL: https://github.com/llvm/llvm-project/commit/ce3d0e87ac23ce4c2be8a3b99c3020f930d7ba16
DIFF: https://github.com/llvm/llvm-project/commit/ce3d0e87ac23ce4c2be8a3b99c3020f930d7ba16.diff

LOG: [mlir][sparse] enable SDDMM-flavored fusion

This rewriting was no longer functional after the recent migration to one-shot
bufferization. This revision makes it work again and adds a CHECK test to
ensure that the fusion happens. Note that the functionality itself is also
exercised by several integration tests.
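
For reference, the SDDMM (sampled dense-dense matrix multiplication) kernel
that this fusion targets computes, for a sparse sampling matrix S and dense
matrices A and B,

    X(i,j) = S(i,j) * sum_k A(i,k) * B(k,j)

evaluated only at the stored entries of S, so the fused kernel never needs to
materialize the full dense product A*B.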

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D130996

Added: 
    mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
    mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
index a01d7c86ada18..f0aef0f0f6386 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp
@@ -41,12 +41,17 @@ static bool isSparseTensor(OpOperand *op) {
   return false;
 }
 
-// Helper method to find zero or empty initialization.
-static bool isEmptyInit(OpOperand *op) {
+// Helper method to find zero/uninitialized allocation.
+static bool isAlloc(OpOperand *op, bool isZero) {
   Value val = op->get();
-  return matchPattern(val, m_Zero()) || matchPattern(val, m_AnyZeroFloat()) ||
-         val.getDefiningOp<InitTensorOp>() ||
-         val.getDefiningOp<AllocTensorOp>();
+  if (auto alloc = val.getDefiningOp<AllocTensorOp>()) {
+    Value copy = alloc.getCopy();
+    if (isZero)
+      return copy && (matchPattern(copy, m_Zero()) ||
+                      matchPattern(copy, m_AnyZeroFloat()));
+    return !copy;
+  }
+  return false;
 }
 
 // Helper to detect sampling operation.
@@ -140,9 +145,9 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
         !prod.getResult(0).hasOneUse())
       return failure();
     // Sampling consumer and sum of multiplication chain producer.
-    if (!isEmptyInit(op.getOutputOperand(0)) ||
-        !isEmptyInit(prod.getOutputOperand(0)) || !isSampling(op) ||
-        !isSumOfMul(prod))
+    if (!isAlloc(op.getOutputOperand(0), /*isZero=*/false) ||
+        !isAlloc(prod.getOutputOperand(0), /*isZero=*/true) ||
+        !isSampling(op) || !isSumOfMul(prod))
       return failure();
     // Modify operand structure of producer and consumer.
     Location loc = prod.getLoc();
@@ -180,6 +185,14 @@ struct FuseSparseMultiplyOverAdd : public OpRewritePattern<GenericOp> {
     mapper.map(last, rewriter.clone(*sampler, mapper)->getResult(0));
     last = rewriter.clone(*acc, mapper)->getResult(0);
     rewriter.create<linalg::YieldOp>(loc, last);
+    // Force initial value on merged allocation for dense outputs.
+    if (!getSparseTensorEncoding(op.getResult(0).getType())) {
+      AllocTensorOp a1 =
+          prod.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
+      AllocTensorOp a2 =
+          op.getOutputOperand(0)->get().getDefiningOp<AllocTensorOp>();
+      a2.getCopyMutable().assign(a1.getCopy());
+    }
     // Replace consumer with fused operation. Old producer
     // and consumer ops will be removed by DCE.
     rewriter.replaceOp(op, fusedOp->getResults());
@@ -240,7 +253,7 @@ struct ReshapeRewriter : public OpRewritePattern<ReshapeOp> {
 //===---------------------------------------------------------------------===//
 
 void mlir::populateSparseTensorRewriting(RewritePatternSet &patterns) {
-  // TODO(springerm): enable FuseSparseMultiplyOverAdd
-  patterns.add<ReshapeRewriter<tensor::ExpandShapeOp>,
-               ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
+  patterns
+      .add<FuseSparseMultiplyOverAdd, ReshapeRewriter<tensor::ExpandShapeOp>,
+           ReshapeRewriter<tensor::CollapseShapeOp>>(patterns.getContext());
 }
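
The new isAlloc predicate keys the fusion on how the two linalg.generic
outputs were allocated: the producer (the matmul) must write into a
zero-initialized allocation, while the consumer (the sampler) must write into
an uninitialized one; for dense outputs the producer's zero copy is then
forwarded to the merged allocation. A minimal sketch of the two shapes the
predicate distinguishes, distilled from the CHECK lines in the test added
below (the %cst/%zero_init/%uninit names are placeholders, not part of this
patch):

  // Matched by isAlloc(..., /*isZero=*/true): allocation initialized by
  // copying a zero constant, as produced by --tensor-copy-insertion.
  %cst = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
  %zero_init = bufferization.alloc_tensor() copy(%cst) : tensor<8x8xf64>

  // Matched by isAlloc(..., /*isZero=*/false): allocation without a copy
  // operand, i.e. left uninitialized.
  %uninit = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>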

diff  --git a/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
new file mode 100644
index 0000000000000..4dd4fd328ce78
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/sparse_sddmm.mlir
@@ -0,0 +1,162 @@
+// RUN: mlir-opt %s  --tensor-copy-insertion --sparsification --cse | FileCheck %s
+
+#SM = #sparse_tensor.encoding<{ dimLevelType = [ "compressed", "compressed" ] }>
+
+#trait_matmul = {
+  indexing_maps = [
+    affine_map<(d0, d1, d2) -> (d1, d0)>,
+    affine_map<(d0, d1, d2) -> (d0, d2)>,
+    affine_map<(d0, d1, d2) -> (d1, d2)>
+  ],
+  iterator_types = ["reduction", "parallel", "parallel"]
+}
+
+#trait_scale = {
+  indexing_maps = [
+    affine_map<(d0, d1) -> (d0, d1)>,
+    affine_map<(d0, d1) -> (d0, d1)>,
+    affine_map<(d0, d1) -> (d0, d1)>
+  ],
+  iterator_types = ["parallel", "parallel"]
+}
+
+// CHECK-LABEL: func.func @sampled_dd_unfused(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<8x8xf64>,
+// CHECK-SAME:    %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64> {
+// CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[VAL_6:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
+// CHECK:         %[[VAL_7:.*]] = bufferization.alloc_tensor() copy(%[[VAL_6]]) {bufferization.escape = [false]} : tensor<8x8xf64>
+// CHECK:         %[[VAL_8:.*]] = bufferization.alloc_tensor() copy(%[[VAL_6]]) {bufferization.escape = [false], memory_space = 0 : ui64} : tensor<8x8xf64>
+// CHECK:         %[[VAL_9:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
+// CHECK:         %[[VAL_10:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
+// CHECK:         %[[VAL_11:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:         %[[VAL_12:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:         %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:         %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:         %[[VAL_15:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
+// CHECK:         %[[VAL_16:.*]] = bufferization.to_memref %[[VAL_8]] : memref<8x8xf64>
+// CHECK:         %[[VAL_17:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:         %[[VAL_18:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:         scf.for %[[VAL_19:.*]] = %[[VAL_17]] to %[[VAL_18]] step %[[VAL_5]] {
+// CHECK:           %[[VAL_20:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:           %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_19]]] : memref<?xindex>
+// CHECK:           %[[VAL_22:.*]] = arith.addi %[[VAL_19]], %[[VAL_5]] : index
+// CHECK:           %[[VAL_23:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_24:.*]] = %[[VAL_21]] to %[[VAL_23]] step %[[VAL_5]] {
+// CHECK:             %[[VAL_25:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_24]]] : memref<?xindex>
+// CHECK:             %[[VAL_26:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<8x8xf64>
+// CHECK:             %[[VAL_27:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_24]]] : memref<?xf64>
+// CHECK:             %[[VAL_28:.*]] = scf.for %[[VAL_29:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_30:.*]] = %[[VAL_26]]) -> (f64) {
+// CHECK:               %[[VAL_31:.*]] = memref.load %[[VAL_9]]{{\[}}%[[VAL_20]], %[[VAL_29]]] : memref<8x8xf64>
+// CHECK:               %[[VAL_32:.*]] = memref.load %[[VAL_10]]{{\[}}%[[VAL_29]], %[[VAL_25]]] : memref<8x8xf64>
+// CHECK:               %[[VAL_33:.*]] = arith.mulf %[[VAL_31]], %[[VAL_32]] : f64
+// CHECK:               %[[VAL_34:.*]] = arith.mulf %[[VAL_33]], %[[VAL_27]] : f64
+// CHECK:               %[[VAL_35:.*]] = arith.addf %[[VAL_30]], %[[VAL_34]] : f64
+// CHECK:               scf.yield %[[VAL_35]] : f64
+// CHECK:             }
+// CHECK:             memref.store %[[VAL_24:.*]], %[[VAL_16]]{{\[}}%[[VAL_20]], %[[VAL_25]]] : memref<8x8xf64>
+// CHECK:           }
+// CHECK:         }
+// CHECK:         %[[VAL_37:.*]] = bufferization.to_tensor %[[VAL_16]] : memref<8x8xf64>
+// CHECK:         return %[[VAL_37]] : tensor<8x8xf64>
+// CHECK:       }
+func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
+                              %arga: tensor<8x8xf64>,
+                              %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
+  // Perform dense-dense matrix matrix multiplication.
+  %1 = arith.constant dense<0.0> : tensor<8x8xf64>
+  %2 = linalg.generic #trait_matmul
+    ins(%arga, %argb : tensor<8x8xf64>, tensor<8x8xf64>)
+    outs(%1 : tensor<8x8xf64>) {
+      ^bb0(%a: f64, %b: f64, %x: f64):
+        %p = arith.mulf %a, %b : f64
+        %q = arith.addf %x, %p : f64
+        linalg.yield %q : f64
+  } -> tensor<8x8xf64>
+  // Sample the result with element-wise multiplication with the sparse matrix.
+  %3 = linalg.generic #trait_scale
+    ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
+    outs(%1 : tensor<8x8xf64>) {
+      ^bb0(%t: f64, %s: f64, %x: f64):
+        %r = arith.mulf %t, %s : f64
+        linalg.yield %r : f64
+  } -> tensor<8x8xf64>
+  return %3 : tensor<8x8xf64>
+}
+
+// CHECK-LABEL: func.func @sparse_sampled_dd_unfused(
+// CHECK-SAME:    %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>,
+// CHECK-SAME:    %[[VAL_1:.*]]: tensor<8x8xf64>,
+// CHECK-SAME:    %[[VAL_2:.*]]: tensor<8x8xf64>) -> tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> {
+// CHECK-DAG:     %[[VAL_3:.*]] = arith.constant 8 : index
+// CHECK-DAG:     %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK-DAG:     %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK-DAG:     %[[VAL_7:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK-DAG:     %[[VAL_8:.*]] = arith.constant dense<0.000000e+00> : tensor<8x8xf64>
+// CHECK:         %[[VAL_9:.*]] = bufferization.alloc_tensor() copy(%[[VAL_8]]) {bufferization.escape = [false]} : tensor<8x8xf64>
+// CHECK:         %[[VAL_10:.*]] = bufferization.alloc_tensor() {bufferization.escape = [false]} : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:         %[[VAL_11:.*]] = bufferization.to_memref %[[VAL_1]] : memref<8x8xf64>
+// CHECK:         %[[VAL_12:.*]] = bufferization.to_memref %[[VAL_2]] : memref<8x8xf64>
+// CHECK:         %[[VAL_13:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_4]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:         %[[VAL_14:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_4]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:         %[[VAL_15:.*]] = sparse_tensor.pointers %[[VAL_0]], %[[VAL_5]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:         %[[VAL_16:.*]] = sparse_tensor.indices %[[VAL_0]], %[[VAL_5]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xindex>
+// CHECK:         %[[VAL_17:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>> to memref<?xf64>
+// CHECK:         %[[VAL_18:.*]] = memref.alloca(%[[VAL_6]]) : memref<?xindex>
+// CHECK:         %[[VAL_19:.*]] = memref.alloca() : memref<f64>
+// CHECK:         %[[VAL_20:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:         %[[VAL_21:.*]] = memref.load %[[VAL_13]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:         scf.for %[[VAL_22:.*]] = %[[VAL_20]] to %[[VAL_21]] step %[[VAL_5]] {
+// CHECK:           %[[VAL_23:.*]] = memref.load %[[VAL_14]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK:           memref.store %[[VAL_23]], %[[VAL_18]]{{\[}}%[[VAL_4]]] : memref<?xindex>
+// CHECK:           %[[VAL_24:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_22]]] : memref<?xindex>
+// CHECK:           %[[VAL_25:.*]] = arith.addi %[[VAL_22]], %[[VAL_5]] : index
+// CHECK:           %[[VAL_26:.*]] = memref.load %[[VAL_15]]{{\[}}%[[VAL_25]]] : memref<?xindex>
+// CHECK:           scf.for %[[VAL_27:.*]] = %[[VAL_24]] to %[[VAL_26]] step %[[VAL_5]] {
+// CHECK:             %[[VAL_28:.*]] = memref.load %[[VAL_16]]{{\[}}%[[VAL_27]]] : memref<?xindex>
+// CHECK:             memref.store %[[VAL_28]], %[[VAL_18]]{{\[}}%[[VAL_5]]] : memref<?xindex>
+// CHECK:             %[[VAL_29:.*]] = memref.load %[[VAL_17]]{{\[}}%[[VAL_27]]] : memref<?xf64>
+// CHECK:             %[[VAL_30:.*]] = scf.for %[[VAL_31:.*]] = %[[VAL_4]] to %[[VAL_3]] step %[[VAL_5]] iter_args(%[[VAL_32:.*]] = %[[VAL_7]]) -> (f64) {
+// CHECK:               memref.store %[[VAL_31]], %[[VAL_18]]{{\[}}%[[VAL_6]]] : memref<?xindex>
+// CHECK:               %[[VAL_33:.*]] = memref.load %[[VAL_11]]{{\[}}%[[VAL_23]], %[[VAL_31]]] : memref<8x8xf64>
+// CHECK:               %[[VAL_34:.*]] = memref.load %[[VAL_12]]{{\[}}%[[VAL_31]], %[[VAL_28]]] : memref<8x8xf64>
+// CHECK:               %[[VAL_35:.*]] = arith.mulf %[[VAL_33]], %[[VAL_34]] : f64
+// CHECK:               %[[VAL_36:.*]] = arith.mulf %[[VAL_35]], %[[VAL_29]] : f64
+// CHECK:               %[[VAL_37:.*]] = arith.addf %[[VAL_32]], %[[VAL_36]] : f64
+// CHECK:               scf.yield %[[VAL_37]] : f64
+// CHECK:             }
+// CHECK:             memref.store %[[VAL_30:.*]], %[[VAL_19]][] : memref<f64>
+// CHECK:             sparse_tensor.lex_insert %[[VAL_10]], %[[VAL_18]], %[[VAL_19]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>, memref<?xindex>, memref<f64>
+// CHECK:           }
+// CHECK:         }
+// CHECK:         %[[VAL_39:.*]] = sparse_tensor.load %[[VAL_10]] hasInserts : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:         return %[[VAL_39]] : tensor<8x8xf64, #sparse_tensor.encoding<{{.*}}>>
+// CHECK:       }
+func.func @sparse_sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
+                                     %arga: tensor<8x8xf64>,
+                                     %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
+  // Perform dense-dense matrix matrix multiplication.
+  %1 = arith.constant dense<0.0> : tensor<8x8xf64>
+  %2 = linalg.generic #trait_matmul
+    ins(%arga, %argb : tensor<8x8xf64>, tensor<8x8xf64>)
+    outs(%1 : tensor<8x8xf64>) {
+      ^bb0(%a: f64, %b: f64, %x: f64):
+        %p = arith.mulf %a, %b : f64
+        %q = arith.addf %x, %p : f64
+        linalg.yield %q : f64
+  } -> tensor<8x8xf64>
+  // Sample the result with element-wise multiplication with the sparse matrix.
+  %3 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
+  %4 = linalg.generic #trait_scale
+    ins(%2, %args : tensor<8x8xf64>, tensor<8x8xf64, #SM>)
+    outs(%3 : tensor<8x8xf64, #SM>) {
+      ^bb0(%t: f64, %s: f64, %x: f64):
+        %r = arith.mulf %t, %s : f64
+        linalg.yield %r : f64
+  } -> tensor<8x8xf64, #SM>
+  return %4 : tensor<8x8xf64, #SM>
+}
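
For intuition, the fused kernel that FuseSparseMultiplyOverAdd produces for
the sparse-output case above is roughly equivalent to the hand-fused form
below (a sketch only; the pattern builds the op programmatically, and the
#trait_fused name, the %out value and the exact operand order are
illustrative, though they mirror the hand-fused @sparse_sampled_dd in the
integration test):

  #trait_fused = {
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d1, d2)>,  // S (sampling matrix)
      affine_map<(d0, d1, d2) -> (d1, d0)>,  // A
      affine_map<(d0, d1, d2) -> (d0, d2)>,  // B
      affine_map<(d0, d1, d2) -> (d1, d2)>   // X (output)
    ],
    iterator_types = ["reduction", "parallel", "parallel"]
  }
  %0 = linalg.generic #trait_fused
    ins(%args, %arga, %argb : tensor<8x8xf64, #SM>, tensor<8x8xf64>, tensor<8x8xf64>)
    outs(%out : tensor<8x8xf64, #SM>) {
      ^bb0(%s: f64, %a: f64, %b: f64, %x: f64):
        %p = arith.mulf %a, %b : f64
        %q = arith.mulf %p, %s : f64
        %r = arith.addf %x, %q : f64
        linalg.yield %r : f64
  } -> tensor<8x8xf64, #SM>

Only the stored entries of %args drive the outer loops after sparsification,
which is exactly what the CHECK lines above verify.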

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
index fc5f8c597a0b0..3c7f741c77fcc 100755
--- a/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_sampled_mm_fusion.mlir
@@ -50,8 +50,8 @@ module {
   // (with dense result).
   //
   func.func @sampled_dd(%args: tensor<8x8xf64, #SM>,
-                   %arga: tensor<8x8xf64>,
-                   %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
+                        %arga: tensor<8x8xf64>,
+                        %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
     %1 = arith.constant dense<0.0> : tensor<8x8xf64>
     %2 = linalg.generic #trait_sampled_dense_dense
       ins(%args, %arga, %argb: tensor<8x8xf64, #SM>,
@@ -71,8 +71,8 @@ module {
   // (with dense result).
   //
   func.func @sampled_dd_unfused(%args: tensor<8x8xf64, #SM>,
-                           %arga: tensor<8x8xf64>,
-                           %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
+                                %arga: tensor<8x8xf64>,
+                                %argb: tensor<8x8xf64>) -> tensor<8x8xf64> {
     // Perform dense-dense matrix matrix multiplication.
     %1 = arith.constant dense<0.0> : tensor<8x8xf64>
     %2 = linalg.generic #trait_matmul
@@ -99,8 +99,8 @@ module {
   // (with sparse result).
   //
   func.func @sparse_sampled_dd(%args: tensor<8x8xf64, #SM>,
-                          %arga: tensor<8x8xf64>,
-                          %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
+                               %arga: tensor<8x8xf64>,
+                               %argb: tensor<8x8xf64>) -> tensor<8x8xf64, #SM> {
     %1 = bufferization.alloc_tensor() : tensor<8x8xf64, #SM>
     %2 = linalg.generic #trait_sampled_dense_dense
       ins(%args, %arga, %argb: tensor<8x8xf64, #SM>,


        

