[Mlir-commits] [mlir] f14c8eb - [mlir][sparse][gpu] refine SDDMM pattern for cuSPARSE

Aart Bik llvmlistbot at llvm.org
Wed Jun 21 18:32:03 PDT 2023


Author: Aart Bik
Date: 2023-06-21T18:31:55-07:00
New Revision: f14c8eb595d0ba84dde09a0d5860f117c11e9ca1

URL: https://github.com/llvm/llvm-project/commit/f14c8eb595d0ba84dde09a0d5860f117c11e9ca1
DIFF: https://github.com/llvm/llvm-project/commit/f14c8eb595d0ba84dde09a0d5860f117c11e9ca1.diff

LOG: [mlir][sparse][gpu] refine SDDMM pattern for cuSPARSE

The old pattern missed some valid cases (e.g. swapped arguments)
but also accepted too many cases (e.g. a non-empty "absent" region,
or different arguments for the add/mul). This commit fixes both issues.

Reviewed By: K-Wu

Differential Revision: https://reviews.llvm.org/D153487

Added: 
    

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
    mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 86c429adf1d04..073aa8419ea53 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -301,12 +301,25 @@ static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
 // Library helper methods.
 //===----------------------------------------------------------------------===//
 
-/// Helper to detect a * b.
-static bool matchMulOfArgs(linalg::GenericOp op, Value val) {
+/// Helper to detect a + b with arguments taken from given block.
+static bool matchAddOfArgs(Block *block, Value val) {
   if (auto *def = val.getDefiningOp()) {
-    if (isa<arith::MulFOp>(def) || isa<arith::MulIOp>(def)) {
-      Value a = op.getBlock()->getArguments()[0];
-      Value b = op.getBlock()->getArguments()[1];
+    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
+      Value a = block->getArguments()[0];
+      Value b = block->getArguments()[1];
+      return (def->getOperand(0) == a && def->getOperand(1) == b) ||
+             (def->getOperand(0) == b && def->getOperand(1) == a);
+    }
+  }
+  return false;
+}
+
+/// Helper to detect a * b with arguments taken from given block.
+static bool matchMulOfArgs(Block *block, Value val) {
+  if (auto *def = val.getDefiningOp()) {
+    if (isa<arith::MulFOp, arith::MulIOp>(def)) {
+      Value a = block->getArguments()[0];
+      Value b = block->getArguments()[1];
       return (def->getOperand(0) == a && def->getOperand(1) == b) ||
              (def->getOperand(0) == b && def->getOperand(1) == a);
     }
@@ -318,67 +331,47 @@ static bool matchMulOfArgs(linalg::GenericOp op, Value val) {
 static bool matchSumOfMultOfArgs(linalg::GenericOp op) {
   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
   if (auto *def = yieldOp.getOperand(0).getDefiningOp()) {
-    if (isa<arith::AddFOp>(def) || isa<arith::AddIOp>(def)) {
+    if (isa<arith::AddFOp, arith::AddIOp>(def)) {
       Value x = op.getBlock()->getArguments()[2];
       return (def->getOperand(0) == x &&
-              matchMulOfArgs(op, def->getOperand(1))) ||
+              matchMulOfArgs(op.getBlock(), def->getOperand(1))) ||
              (def->getOperand(1) == x &&
-              matchMulOfArgs(op, def->getOperand(0)));
+              matchMulOfArgs(op.getBlock(), def->getOperand(0)));
     }
   }
   return false;
 }
 
-// Helper to detect c = c \spy (a * b)
+// Helper to detect c += spy(s) x (a * b)
 static bool matchSumReductionOfMulUnary(linalg::GenericOp op) {
   auto yieldOp = cast<linalg::YieldOp>(op.getRegion().front().getTerminator());
-  auto def = yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>();
-  if (!def)
-    return false;
+  // The linalg yields a custom reduce result.
   Value s_out = op.getBlock()->getArguments()[2];
-  for (auto *use : s_out.getUsers()) {
-    if (!(isa<sparse_tensor::ReduceOp>(use) ||
-          isa<sparse_tensor::UnaryOp>(use)))
-      return false;
-    // The sparse matrix should be specified as the pattern in the two
-    // operators.
-    if (s_out != use->getOperand(0))
+  if (auto redOp =
+          yieldOp.getOperand(0).getDefiningOp<sparse_tensor::ReduceOp>()) {
+    // The reduce consumes the output.
+    Value other;
+    if (s_out == redOp->getOperand(0))
+      other = redOp->getOperand(1);
+    else if (s_out == redOp->getOperand(1))
+      other = redOp->getOperand(0);
+    else
       return false;
-
-    // the above logic makes sure the pattern involves reduction and unary,
-    // i.e.,
-    // %1 = sparse_tensor.unary
-    // %2 = sparse_tensor.reduce
-    // we need to make sure %1 produces A*B and %2 uses summation as the
-    // reduction operator.
-    if (isa<sparse_tensor::ReduceOp>(use)) {
-      auto reduceSpOp = cast<sparse_tensor::ReduceOp>(use);
-      auto yieldSpOp = cast<sparse_tensor::YieldOp>(
-          reduceSpOp.getRegion().front().getTerminator());
-      auto *reduce = yieldSpOp.getOperand(0).getDefiningOp();
-      if (!isa_and_nonnull<arith::AddFOp>(reduce) &&
-          !isa_and_nonnull<arith::AddIOp>(reduce))
-        return false;
-    }
-    if (isa<sparse_tensor::UnaryOp>(use)) {
-      auto unarySpOp = cast<sparse_tensor::UnaryOp>(use);
-      auto yieldSpOp = cast<sparse_tensor::YieldOp>(
-          unarySpOp.getRegion(0).front().getTerminator());
-      auto *unary = yieldSpOp.getOperand(0).getDefiningOp();
-      if (!isa_and_nonnull<arith::MulIOp>(unary) &&
-          !isa_and_nonnull<arith::MulFOp>(unary))
+    // The reduce op also consumes an unary which also consumes the output
+    // and does not define an absent value.
+    if (auto unOp = other.getDefiningOp<sparse_tensor::UnaryOp>()) {
+      if (s_out != unOp->getOperand(0) || !unOp.getAbsentRegion().empty())
         return false;
-
-      // we also need to make sure the unary operation is used by the reduction
-      // operation.
-      for (auto *useUnary : unarySpOp->getUsers()) {
-        if (!isa<sparse_tensor::ReduceOp>(useUnary)) {
-          return false;
-        }
-      }
+      // And the bodies are as expected.
+      auto yieldUn = cast<sparse_tensor::YieldOp>(
+          unOp.getRegion(0).front().getTerminator());
+      auto yieldRed = cast<sparse_tensor::YieldOp>(
+          redOp.getRegion().front().getTerminator());
+      return matchMulOfArgs(op.getBlock(), yieldUn.getOperand(0)) &&
+             matchAddOfArgs(&redOp.getRegion().front(), yieldRed.getOperand(0));
     }
   }
-  return true;
+  return false;
 }
 
 /// Test for sorted COO with suitable data and coordinates types.
@@ -679,37 +672,24 @@ static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
   return success();
 }
 
-// TODO: identify alpha and beta and pass them to the CUDA calls
 /// Match and rewrite SDDMM kernel.
 static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
                                   linalg::GenericOp op, bool enableRT) {
-
-  // For now, this pass reuses C and copies the result non-zero elements to
-  // overwrite C's.
-  // As an ad hoc solution, this pass also assumes the linalg takes a,b,c as
-  // input argument, and c as the output. It recognizes this pattern and rewrite
-  // it.
-
   Location loc = op.getLoc();
   Value a = op.getOperand(0);
   Value b = op.getOperand(1);
   Value c = op.getOperand(2);
-
   SmallVector<Value> tokens;
 
-  // Only admissible sparse matrix format and dense matrices.
+  // Only admissible sparse matrix format and dense matrices, no COO.
   bool isCOO = false;
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType bTp = getSparseTensorType(b);
   SparseTensorType cTp = getSparseTensorType(c);
-
   if (!areAdmissibleTypes(cTp, bTp, aTp, enableRT, false, isCOO))
     return failure();
-
-  // cusparse currently does not support COO in its SDDMM kernel.
-  if (isCOO) {
+  if (isCOO)
     return failure();
-  }
 
   // The SDDMM does the in-place operation.
   // Start sparse kernel and copy data from host to device.
@@ -798,8 +778,8 @@ static LogicalResult rewriteSDDMM(PatternRewriter &rewriter,
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
 
+  // Done.
   rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
-
   return success();
 }
 
@@ -933,6 +913,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
     bindDims(getContext(), i, j, k);
 
     // TODO: more robust patterns, tranposed versions, more kernels...
+    // TODO: identify alpha and beta and pass them to the CUDA calls
 
     // Recognize a SpMV kernel.
     if (numLoops == 2 && numTensors == 3 &&

diff  --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir
index 0ac9b4d76b38c..2cd9e2847a623 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_sampled_matmul_lib.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s --linalg-generalize-named-ops \
-// RUN:             --sparsification="enable-gpu-libgen" | FileCheck %s
+// RUN: mlir-opt %s --sparsification="enable-gpu-libgen" | FileCheck %s
 
 #trait_sampled_dense_dense = {
   indexing_maps = [
@@ -22,8 +21,6 @@
 
 #CSR = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
 
-module {
-
 // CHECK-LABEL:   func.func @sparse_sampled_dd(
 // CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<8x8xf64, #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>>,
 // CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<8x8xf64>,
@@ -82,7 +79,7 @@ module {
 // A kernel that computes a direct sampled matrix matrix multiplication
 // (with sparse result).
 // Compute SDDMM C = C\spy AB
-// 
+//
 func.func @sparse_sampled_dd(%argS: tensor<8x8xf64, #CSR>,
                                %argA: tensor<8x8xf64>,
                                %argB: tensor<8x8xf64>) -> tensor<8x8xf64, #CSR> {
@@ -106,6 +103,4 @@ func.func @sparse_sampled_dd(%argS: tensor<8x8xf64, #CSR>,
            linalg.yield %r : f64
     } -> tensor<8x8xf64, #CSR>
     return %result : tensor<8x8xf64, #CSR>
-  }
-
 }


        


More information about the Mlir-commits mailing list