[Mlir-commits] [mlir] b75d6a4 - [mlir][sparse][gpu] recognize SpMM cuSparse during sparsification

Aart Bik llvmlistbot at llvm.org
Fri May 19 17:23:09 PDT 2023


Author: Aart Bik
Date: 2023-05-19T17:22:59-07:00
New Revision: b75d6a40f15efdc5ba593cadd3460a6fbbb28af0

URL: https://github.com/llvm/llvm-project/commit/b75d6a40f15efdc5ba593cadd3460a6fbbb28af0
DIFF: https://github.com/llvm/llvm-project/commit/b75d6a40f15efdc5ba593cadd3460a6fbbb28af0.diff

LOG: [mlir][sparse][gpu] recognize SpMM cuSparse during sparsification

Reviewed By: Peiming

Differential Revision: https://reviews.llvm.org/D150715
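
For context, a minimal sketch (taken directly from the new test added below) of the kind of kernel the sparsifier now recognizes: a linalg.matmul whose A operand carries a CSR sparse encoding while B and C remain dense. Under --linalg-generalize-named-ops and --sparsification="enable-gpu-libgen", this kernel is rewritten into gpu.create_csr, gpu.spmm_buffer_size, and gpu.spmm calls, bracketed by the host/device copies checked in the test (a COO-encoded A is also accepted when the runtime library is enabled).

    #CSR = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>

    // C = A * B, with A sparse (CSR) and B, C dense.
    func.func @matmul(%A: tensor<?x?xf64, #CSR>,
                      %B: tensor<?x?xf64>,
                      %C_in: tensor<?x?xf64>) -> tensor<?x?xf64> {
      %C_out = linalg.matmul
          ins(%A, %B: tensor<?x?xf64, #CSR>, tensor<?x?xf64>)
          outs(%C_in: tensor<?x?xf64>) -> tensor<?x?xf64>
      return %C_out : tensor<?x?xf64>
    }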

Added: 
    mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index ac199e02d95d6..cadbe86942274 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -347,6 +347,19 @@ static bool isAdmissibleCSR(SparseTensorType &aTp) {
           aTp.getCrdWidth() == 64);
 }
 
+/// Test for admissible types on operands (with output parameter `isCOO`).
+static bool areAdmissibleTypes(SparseTensorType aTp, SparseTensorType bTp,
+                               SparseTensorType cTp, bool enableRT,
+                               bool &isCOO) {
+  if (bTp.hasEncoding() || cTp.hasEncoding())
+    return false;
+  if (isAdmissibleCOO(aTp)) {
+    isCOO = true;
+    return enableRT; // TODO: CreateCooAoSOp was deprecated, find another way
+  }
+  return isAdmissibleCSR(aTp);
+}
+
 /// Generates the first positions/coordinates of a sparse matrix.
 static Value genFirstPosOrCrds(OpBuilder &builder, Location loc, Value a,
                                bool isCOO, bool enableRT) {
@@ -371,17 +384,17 @@ static Value genSecondCrds(OpBuilder &builder, Location loc, Value a,
 /// Generates the sparse matrix multiplication.
 static Operation *genSpMat(OpBuilder &builder, Location loc, Type handleTp,
                            Type tokenTp, Value token, Value szY, Value szX,
-                           Value nnzA, Value rowA, Value colA, Value valA,
+                           Value nseA, Value rowA, Value colA, Value valA,
                            bool isCOO, bool enableRT) {
   if (isCOO) {
     // Library uses SoA COO, direct IR uses AoS COO.
     if (enableRT)
       return builder.create<gpu::CreateCooOp>(loc, handleTp, tokenTp, token,
-                                              szY, szX, nnzA, rowA, colA, valA);
+                                              szY, szX, nseA, rowA, colA, valA);
     llvm_unreachable("gpu::CreateCooAoSOp is deprecated");
   }
   return builder.create<gpu::CreateCsrOp>(loc, handleTp, tokenTp, token, szY,
-                                          szX, nnzA, rowA, colA, valA);
+                                          szX, nseA, rowA, colA, valA);
 }
 
 /// Match and rewrite SpMV kernel.
@@ -393,29 +406,19 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
   Value y = op.getOperand(2); // we have y = Ax
   SmallVector<Value> tokens;
 
-  // Only admissible sparse matrix format and dense vectors for now.
+  // Only admissible sparse matrix format and dense vectors.
   bool isCOO = false;
   SparseTensorType aTp = getSparseTensorType(a);
   SparseTensorType xTp = getSparseTensorType(x);
   SparseTensorType yTp = getSparseTensorType(y);
-  if (xTp.hasEncoding() || yTp.hasEncoding())
-    return failure();
-  if (isAdmissibleCOO(aTp)) {
-    isCOO = true;
-    // TODO: CreateCooAoSOp was deprecated, find another way
-    if (!enableRT)
-      return failure();
-  } else if (isAdmissibleCSR(aTp)) {
-    isCOO = false;
-  } else {
+  if (!areAdmissibleTypes(aTp, xTp, yTp, enableRT, isCOO))
     return failure();
-  }
 
   // Start sparse kernel and copy data from host to device.
   //   a : memR/memC/memV -> rowA,colA,valA
   //   x : memX           -> vecX
   //   y : memY           -> vecY
-  Value nnzA = rewriter.create<NumberOfEntriesOp>(loc, a);
+  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
   Value szY = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
   Value szX = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
   Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
@@ -441,7 +444,7 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
   Value handle = env.getResult(0);
   token = env.getAsyncToken();
   Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szY,
-                               szX, nnzA, rowA, colA, valA, isCOO, enableRT);
+                               szX, nseA, rowA, colA, valA, isCOO, enableRT);
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
   auto dvecX = rewriter.create<gpu::CreateDnVecOp>(loc, handleTp, tokenTp,
@@ -500,7 +503,105 @@ static LogicalResult rewriteSpMV(PatternRewriter &rewriter,
 /// Match and rewrite SpMM kernel.
 static LogicalResult rewriteSpMM(PatternRewriter &rewriter,
                                  linalg::GenericOp op, bool enableRT) {
-  return failure(); // TODO: implement
+  Location loc = op.getLoc();
+  Value a = op.getOperand(0);
+  Value b = op.getOperand(1);
+  Value c = op.getOperand(2); // we have C = AB
+  SmallVector<Value> tokens;
+
+  // Only admissible sparse matrix format and dense matrices.
+  bool isCOO = false;
+  SparseTensorType aTp = getSparseTensorType(a);
+  SparseTensorType bTp = getSparseTensorType(b);
+  SparseTensorType cTp = getSparseTensorType(c);
+  if (!areAdmissibleTypes(aTp, bTp, cTp, enableRT, isCOO))
+    return failure();
+
+  // Start sparse kernel and copy data from host to device.
+  //   a : memR/memC/memV -> rowA,colA,valA
+  //   b : bufB           -> matB
+  //   c : bufC           -> matC
+  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
+  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
+  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
+  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
+  Value memR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
+  Value memC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+  Value memV = genToValues(rewriter, loc, a);
+  Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
+  Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
+  Value valA = genAllocCopy(rewriter, loc, memV, tokens);
+  Value bufB = genTensorToMemref(rewriter, loc, b);
+  Value matB = genAllocCopy(rewriter, loc, bufB, tokens);
+  Value bufC = genTensorToMemref(rewriter, loc, c);
+  Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+
+  // Create sparse environment and sparse matrix/dense matrix handles.
+  Type indexTp = rewriter.getIndexType();
+  Type handleTp = rewriter.getType<gpu::SparseHandleType>();
+  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
+  Value token = genFirstWait(rewriter, loc);
+  auto env =
+      rewriter.create<gpu::CreateSparseEnvOp>(loc, handleTp, tokenTp, token);
+  Value handle = env.getResult(0);
+  token = env.getAsyncToken();
+  Operation *spGenA = genSpMat(rewriter, loc, handleTp, tokenTp, token, szm,
+                               szk, nseA, rowA, colA, valA, isCOO, enableRT);
+  Value spMatA = spGenA->getResult(0);
+  token = spGenA->getResult(1);
+  auto dmatB = rewriter.create<gpu::CreateDnMatOp>(loc, handleTp, tokenTp,
+                                                   token, szk, szn, matB);
+  Value dnB = dmatB.getResult(0);
+  token = dmatB.getAsyncToken();
+  auto dmatC = rewriter.create<gpu::CreateDnMatOp>(loc, handleTp, tokenTp,
+                                                   token, szm, szn, matC);
+  Value dnC = dmatC.getResult(0);
+  token = dmatC.getAsyncToken();
+
+  // Precompute buffersize for SpMM.
+  auto bufferComp = rewriter.create<gpu::SpMMBufferSizeOp>(
+      loc, indexTp, tokenTp, token, handle, spMatA, dnB, dnC);
+  Value bufferSz = bufferComp.getResult(0);
+  token = bufferComp.getAsyncToken();
+  auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
+  Value buffer = buf.getResult(0);
+  token = buf.getAsyncToken();
+
+  // Perform the SpMM.
+  auto spmmComp = rewriter.create<gpu::SpMMOp>(loc, tokenTp, token, handle,
+                                               spMatA, dnB, dnC, buffer);
+  token = spmmComp.getAsyncToken();
+
+  // Copy data back to host and free all the resources.
+  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroyDnMatOp>(loc, tokenTp, token, dnB)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroyDnMatOp>(loc, tokenTp, token, dnC)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroySparseEnvOp>(loc, tokenTp, token, handle)
+              .getAsyncToken();
+  tokens.push_back(token);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+  token = genFirstWait(rewriter, loc);
+  token = genCopyMemRef(rewriter, loc, bufC, matC, token);
+  token = genDeallocMemRef(rewriter, loc, rowA, token);
+  if (colA)
+    token = genDeallocMemRef(rewriter, loc, colA, token);
+  token = genDeallocMemRef(rewriter, loc, valA, token);
+  token = genDeallocMemRef(rewriter, loc, buffer, token);
+  token = genDeallocMemRef(rewriter, loc, matB, token);
+  token = genDeallocMemRef(rewriter, loc, matC, token);
+  tokens.push_back(token);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+
+  // Done.
+  rewriter.replaceOp(op, op.getDpsInitOperand(0)->get());
+  return success();
 }
 
 //===----------------------------------------------------------------------===//
@@ -610,7 +711,7 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
 //===----------------------------------------------------------------------===//
 
 /// Proof-of-concept rewriter. This rule recognizes certain math kernels
-/// and replaces these with corresponding calls into the sparse library.
+/// and replaces these with corresponding calls into a sparse library.
 struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
   using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
 

diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir
new file mode 100644
index 0000000000000..71f3bced37882
--- /dev/null
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt %s --linalg-generalize-named-ops \
+// RUN:             --sparsification="enable-gpu-libgen" | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{ lvlTypes = [ "dense", "compressed" ] }>
+
+//
+// Compute matrix matrix C = AB
+//
+// CHECK-LABEL:   func.func @matmul(
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>,
+// CHECK-SAME:      %[[VAL_1:.*]]: tensor<?x?xf64>,
+// CHECK-SAME:      %[[VAL_2:.*]]: tensor<?x?xf64>) -> tensor<?x?xf64> {
+// CHECK-DAG:       %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK-DAG:       %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK-DAG:       %[[VAL_5:.*]] = sparse_tensor.number_of_entries %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_6:.*]] = tensor.dim %[[VAL_0]], %[[VAL_3]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x?xf64>
+// CHECK-DAG:       %[[VAL_8:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>>
+// CHECK-DAG:       %[[VAL_9:.*]] = sparse_tensor.positions %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex>
+// CHECK-DAG:       %[[VAL_10:.*]] = sparse_tensor.coordinates %[[VAL_0]] {level = 1 : index} : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xindex, strided<[?], offset: ?>>
+// CHECK:           %[[VAL_11:.*]] = sparse_tensor.values %[[VAL_0]] : tensor<?x?xf64, #sparse_tensor.encoding<{{{.*}}}>> to memref<?xf64>
+// CHECK:           %[[VAL_12:.*]] = gpu.wait async
+// CHECK:           %[[VAL_13:.*]] = memref.dim %[[VAL_9]], %[[VAL_3]] : memref<?xindex>
+// CHECK:           %[[VAL_14:.*]], %[[VAL_15:.*]] = gpu.alloc async {{\[}}%[[VAL_12]]] (%[[VAL_13]]) : memref<?xindex>
+// CHECK:           %[[VAL_16:.*]] = gpu.memcpy async {{\[}}%[[VAL_15]]] %[[VAL_14]], %[[VAL_9]] : memref<?xindex>, memref<?xindex>
+// CHECK:           %[[VAL_17:.*]] = gpu.wait async
+// CHECK:           %[[VAL_18:.*]] = memref.dim %[[VAL_10]], %[[VAL_3]] : memref<?xindex, strided<[?], offset: ?>>
+// CHECK:           %[[VAL_19:.*]], %[[VAL_20:.*]] = gpu.alloc async {{\[}}%[[VAL_17]]] (%[[VAL_18]]) : memref<?xindex>
+// CHECK:           %[[VAL_21:.*]] = gpu.memcpy async {{\[}}%[[VAL_20]]] %[[VAL_19]], %[[VAL_10]] : memref<?xindex>, memref<?xindex, strided<[?], offset: ?>>
+// CHECK:           %[[VAL_22:.*]] = gpu.wait async
+// CHECK:           %[[VAL_23:.*]] = memref.dim %[[VAL_11]], %[[VAL_3]] : memref<?xf64>
+// CHECK:           %[[VAL_24:.*]], %[[VAL_25:.*]] = gpu.alloc async {{\[}}%[[VAL_22]]] (%[[VAL_23]]) : memref<?xf64>
+// CHECK:           %[[VAL_26:.*]] = gpu.memcpy async {{\[}}%[[VAL_25]]] %[[VAL_24]], %[[VAL_11]] : memref<?xf64>, memref<?xf64>
+// CHECK:           %[[VAL_27:.*]] = bufferization.to_memref %[[VAL_1]] : memref<?x?xf64>
+// CHECK:           %[[VAL_28:.*]] = gpu.wait async
+// CHECK:           %[[VAL_29:.*]] = memref.dim %[[VAL_27]], %[[VAL_3]] : memref<?x?xf64>
+// CHECK:           %[[VAL_30:.*]] = memref.dim %[[VAL_27]], %[[VAL_4]] : memref<?x?xf64>
+// CHECK:           %[[VAL_31:.*]], %[[VAL_32:.*]] = gpu.alloc async {{\[}}%[[VAL_28]]] (%[[VAL_29]], %[[VAL_30]]) : memref<?x?xf64>
+// CHECK:           %[[VAL_33:.*]] = gpu.memcpy async {{\[}}%[[VAL_32]]] %[[VAL_31]], %[[VAL_27]] : memref<?x?xf64>, memref<?x?xf64>
+// CHECK:           %[[VAL_34:.*]] = bufferization.to_memref %[[VAL_2]] : memref<?x?xf64>
+// CHECK:           %[[VAL_35:.*]] = gpu.wait async
+// CHECK:           %[[VAL_36:.*]] = memref.dim %[[VAL_34]], %[[VAL_3]] : memref<?x?xf64>
+// CHECK:           %[[VAL_37:.*]] = memref.dim %[[VAL_34]], %[[VAL_4]] : memref<?x?xf64>
+// CHECK:           %[[VAL_38:.*]], %[[VAL_39:.*]] = gpu.alloc async {{\[}}%[[VAL_35]]] (%[[VAL_36]], %[[VAL_37]]) : memref<?x?xf64>
+// CHECK:           %[[VAL_40:.*]] = gpu.memcpy async {{\[}}%[[VAL_39]]] %[[VAL_38]], %[[VAL_34]] : memref<?x?xf64>, memref<?x?xf64>
+// CHECK:           gpu.wait {{\[}}%[[VAL_16]], %[[VAL_21]], %[[VAL_26]], %[[VAL_33]], %[[VAL_40]]]
+// CHECK:           %[[VAL_41:.*]] = gpu.wait async
+// CHECK:           %[[VAL_42:.*]], %[[VAL_43:.*]] = gpu.create_sparse_env async {{\[}}%[[VAL_41]]]
+// CHECK:           %[[VAL_44:.*]], %[[VAL_45:.*]] = gpu.create_csr async {{\[}}%[[VAL_43]]] %[[VAL_6]], %[[VAL_8]], %[[VAL_5]], %[[VAL_14]], %[[VAL_19]], %[[VAL_24]] : memref<?xindex>, memref<?xindex>, memref<?xf64>
+// CHECK:           %[[VAL_46:.*]], %[[VAL_47:.*]] = gpu.create_dn_mat async {{\[}}%[[VAL_45]]] %[[VAL_8]], %[[VAL_7]], %[[VAL_31]] : memref<?x?xf64>
+// CHECK:           %[[VAL_48:.*]], %[[VAL_49:.*]] = gpu.create_dn_mat async {{\[}}%[[VAL_47]]] %[[VAL_6]], %[[VAL_7]], %[[VAL_38]] : memref<?x?xf64>
+// CHECK:           %[[VAL_50:.*]], %[[VAL_51:.*]] = gpu.spmm_buffer_size async {{\[}}%[[VAL_49]]] %[[VAL_42]], %[[VAL_44]], %[[VAL_46]], %[[VAL_48]]
+// CHECK:           %[[VAL_52:.*]], %[[VAL_53:.*]] = gpu.alloc async {{\[}}%[[VAL_51]]] (%[[VAL_50]]) : memref<?xi8>
+// CHECK:           %[[VAL_54:.*]] = gpu.spmm async {{\[}}%[[VAL_53]]] %[[VAL_42]], %[[VAL_44]], %[[VAL_46]], %[[VAL_48]], %[[VAL_52]] : memref<?xi8>
+// CHECK:           %[[VAL_55:.*]] = gpu.destroy_sp_mat async {{\[}}%[[VAL_54]]] %[[VAL_44]]
+// CHECK:           %[[VAL_56:.*]] = gpu.destroy_dn_mat async {{\[}}%[[VAL_55]]] %[[VAL_46]]
+// CHECK:           %[[VAL_57:.*]] = gpu.destroy_dn_mat async {{\[}}%[[VAL_56]]] %[[VAL_48]]
+// CHECK:           %[[VAL_58:.*]] = gpu.destroy_sparse_env async {{\[}}%[[VAL_57]]] %[[VAL_42]]
+// CHECK:           gpu.wait {{\[}}%[[VAL_58]]]
+// CHECK:           %[[VAL_59:.*]] = gpu.wait async
+// CHECK:           %[[VAL_60:.*]] = gpu.memcpy async {{\[}}%[[VAL_59]]] %[[VAL_34]], %[[VAL_38]] : memref<?x?xf64>, memref<?x?xf64>
+// CHECK:           %[[VAL_61:.*]] = gpu.dealloc async {{\[}}%[[VAL_60]]] %[[VAL_14]] : memref<?xindex>
+// CHECK:           %[[VAL_62:.*]] = gpu.dealloc async {{\[}}%[[VAL_61]]] %[[VAL_19]] : memref<?xindex>
+// CHECK:           %[[VAL_63:.*]] = gpu.dealloc async {{\[}}%[[VAL_62]]] %[[VAL_24]] : memref<?xf64>
+// CHECK:           %[[VAL_64:.*]] = gpu.dealloc async {{\[}}%[[VAL_63]]] %[[VAL_52]] : memref<?xi8>
+// CHECK:           %[[VAL_65:.*]] = gpu.dealloc async {{\[}}%[[VAL_64]]] %[[VAL_31]] : memref<?x?xf64>
+// CHECK:           %[[VAL_66:.*]] = gpu.dealloc async {{\[}}%[[VAL_65]]] %[[VAL_38]] : memref<?x?xf64>
+// CHECK:           gpu.wait {{\[}}%[[VAL_66]]]
+// CHECK:           return %[[VAL_2]] : tensor<?x?xf64>
+// CHECK:         }
+func.func @matmul(%A: tensor<?x?xf64, #CSR>, %B: tensor<?x?xf64>, %C_in: tensor<?x?xf64>) -> tensor<?x?xf64> {
+  %C_out = linalg.matmul
+      ins(%A, %B: tensor<?x?xf64, #CSR>, tensor<?x?xf64>)
+      outs(%C_in: tensor<?x?xf64>) -> tensor<?x?xf64>
+  return %C_out : tensor<?x?xf64>
+}
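
As a quick usage sketch, the new test can be exercised standalone with a command that mirrors its RUN line (using the path at which the file was added):

    mlir-opt mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir \
        --linalg-generalize-named-ops \
        --sparsification="enable-gpu-libgen" \
      | FileCheck mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib.mlir

The FileCheck lines above verify that the lowered IR contains the expected gpu.create_csr / gpu.spmm_buffer_size / gpu.spmm sequence together with the surrounding async copies, waits, and deallocations.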


        

