[Mlir-commits] [mlir] 76a80a0 - [mlir][sparse][gpu] sparsifier GPU libgen for SpGEMM in cuSparse

Aart Bik llvmlistbot at llvm.org
Thu Aug 10 14:52:25 PDT 2023


Author: Aart Bik
Date: 2023-08-10T14:52:16-07:00
New Revision: 76a80a080872350d70fc3b3d57b9db8bee54e1df

URL: https://github.com/llvm/llvm-project/commit/76a80a080872350d70fc3b3d57b9db8bee54e1df
DIFF: https://github.com/llvm/llvm-project/commit/76a80a080872350d70fc3b3d57b9db8bee54e1df.diff

LOG: [mlir][sparse][gpu] sparsifier GPU libgen for SpGEMM in cuSparse

Includes a working end-to-end integration test.
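
The sparsifier's GPU libgen path now recognizes a sparse-times-sparse matmul
(SpGEMM) and maps it onto cuSPARSE's SpGEMM API (work estimation, compute,
get-size, and copy steps). The recognized kernel shape is an all-CSR matmul,
C = A x B; the snippet below is taken from the integration test added in
this revision:

  #CSR = #sparse_tensor.encoding<{
    lvlTypes = [ "dense", "compressed" ],
    posWidth = 32,
    crdWidth = 32
  }>

  // Computes C = A x B with A,B,C sparse CSR.
  func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
                       %B: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR> {
    %init = bufferization.alloc_tensor() : tensor<8x8xf32, #CSR>
    %C = linalg.matmul
      ins(%A, %B: tensor<8x8xf32, #CSR>,
                  tensor<8x8xf32, #CSR>)
      outs(%init: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR>
    return %C: tensor<8x8xf32, #CSR>
  }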

Reviewed By: K-Wu

Differential Revision: https://reviews.llvm.org/D157652

Added: 
    mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir

Modified: 
    mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index f9c35d8b14d2e0..98a61b19fc55ea 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -151,13 +151,25 @@ static gpu::AllocOp genAllocMemRef(OpBuilder &builder, Location loc, Value mem,
                                       token, dynamicSizes, ValueRange());
 }
 
+// Allocates a typed buffer on the host with given size.
+static Value genHostBuffer(OpBuilder &builder, Location loc, Type type,
+                           Value size) {
+  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
+  return builder.create<memref::AllocOp>(loc, memTp, size).getResult();
+}
+
+// Allocates a typed buffer on the device with given size.
+static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Type type,
+                                   Value size, Value token) {
+  const auto memTp = MemRefType::get({ShapedType::kDynamic}, type);
+  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
+                                      token, size, ValueRange());
+}
+
 // Allocates a void buffer on the device with given size.
 static gpu::AllocOp genAllocBuffer(OpBuilder &builder, Location loc, Value size,
                                    Value token) {
-  const auto memTp =
-      MemRefType::get({ShapedType::kDynamic}, builder.getI8Type());
-  return builder.create<gpu::AllocOp>(loc, TypeRange({memTp, token.getType()}),
-                                      token, size, ValueRange());
+  return genAllocBuffer(builder, loc, builder.getI8Type(), size, token);
 }
 
 /// Deallocates memory from the device.
@@ -198,7 +210,6 @@ static Value genTensorToMemref(PatternRewriter &rewriter, Location loc,
 /// assume that the first buffer is the one allocated for output. We create
 /// a set of properly chained asynchronous allocation/copy pairs to increase
 /// overlap before launching the kernel.
-/// TODO: the output assumption may be a bit too brittle
 static Value genParametersIn(OpBuilder &builder, Location loc,
                              SmallVectorImpl<Value> &scalars,
                              SmallVectorImpl<Value> &buffers,
@@ -571,6 +582,7 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   token = genDeallocMemRef(rewriter, loc, vecY, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
     genHostUnregisterMemref(rewriter, loc, castR);
     if (memC)
@@ -579,7 +591,6 @@ rewriteSpMV(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
     genHostUnregisterMemref(rewriter, loc, castX);
     genHostUnregisterMemref(rewriter, loc, castY);
   }
-  tokens.clear();
 
   // Done.
   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, memY);
@@ -630,7 +641,6 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
     castB = genHostRegisterMemref(rewriter, loc, bufB);
     castBufC = genHostRegisterMemref(rewriter, loc, bufC);
   }
-
   Value rowA = genAllocCopy(rewriter, loc, memR, tokens);
   Value colA = memC ? genAllocCopy(rewriter, loc, memC, tokens) : Value();
   Value valA = genAllocCopy(rewriter, loc, memV, tokens);
@@ -702,6 +712,7 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   token = genDeallocMemRef(rewriter, loc, matC, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
     genHostUnregisterMemref(rewriter, loc, castR);
     if (memC)
@@ -710,14 +721,179 @@ rewriteSpMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
     genHostUnregisterMemref(rewriter, loc, castB);
     genHostUnregisterMemref(rewriter, loc, castC);
   }
-  tokens.clear();
 
   // Done.
   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
   return success();
 }
 
-// Match and rewrite 2:4 SpMM kernels.
+// Match and rewrite SpGEMM kernel.
+static LogicalResult
+rewriteSpGEMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
+              GPUDataTransferStrategy gpuDataTransferStrategy) {
+  Location loc = op.getLoc();
+  Value a = op.getOperand(0);
+  Value b = op.getOperand(1);
+  Value c = op.getOperand(2); // we have C = AB
+  SmallVector<Value> tokens;
+
+  // Only CSR <- CSR x CSR supported.
+  bool isCOO = false;
+  SparseTensorType aTp = getSparseTensorType(a);
+  SparseTensorType bTp = getSparseTensorType(b);
+  SparseTensorType cTp = getSparseTensorType(c);
+  if (!isAdmissibleCSR(aTp) || !isAdmissibleCSR(bTp) || !isAdmissibleCSR(cTp))
+    return failure();
+
+  // Start sparse kernel and copy data from host to device.
+  //   a : amemR/amemC/amemV -> rowA,colA,valA
+  //   b : bmemR/bmemC/bmemV -> rowB,colB,valB
+  //   c : materializes
+  auto dnCType = cTp.getElementType();
+  Value nseA = rewriter.create<NumberOfEntriesOp>(loc, a);
+  Value nseB = rewriter.create<NumberOfEntriesOp>(loc, b);
+  Value szm = linalg::createOrFoldDimOp(rewriter, loc, a, 0);
+  Value szk = linalg::createOrFoldDimOp(rewriter, loc, a, 1);
+  Value szn = linalg::createOrFoldDimOp(rewriter, loc, b, 1);
+  Value amemR = genFirstPosOrCrds(rewriter, loc, a, isCOO, enableRT);
+  Value amemC = genSecondCrds(rewriter, loc, a, isCOO, enableRT);
+  Value amemV = genToValues(rewriter, loc, a);
+  Value bmemR = genFirstPosOrCrds(rewriter, loc, b, isCOO, enableRT);
+  Value bmemC = genSecondCrds(rewriter, loc, b, isCOO, enableRT);
+  Value bmemV = genToValues(rewriter, loc, b);
+  Value rowA = genAllocCopy(rewriter, loc, amemR, tokens);
+  Value colA = genAllocCopy(rewriter, loc, amemC, tokens);
+  Value valA = genAllocCopy(rewriter, loc, amemV, tokens);
+  Value rowB = genAllocCopy(rewriter, loc, bmemR, tokens);
+  Value colB = genAllocCopy(rewriter, loc, bmemC, tokens);
+  Value valB = genAllocCopy(rewriter, loc, bmemV, tokens);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+
+  // Create sparse environment and sparse matrix handles.
+  Type indexTp = rewriter.getIndexType();
+  Type spmatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
+  Type descTp = rewriter.getType<gpu::SparseSpGEMMOpHandleType>();
+  Type tokenTp = rewriter.getType<gpu::AsyncTokenType>();
+  Value token = genFirstWait(rewriter, loc);
+  Operation *spGenA =
+      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szk, nseA,
+               rowA, colA, valA, isCOO, enableRT);
+  Value spMatA = spGenA->getResult(0);
+  token = spGenA->getResult(1);
+  Operation *spGenB =
+      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szk, szn, nseB,
+               rowB, colB, valB, isCOO, enableRT);
+  Value spMatB = spGenB->getResult(0);
+  token = spGenB->getResult(1);
+
+  // Sparse matrix C materializes (also assumes beta == 0).
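+  // The coordinate and value buffers of C start out zero-sized; once the
+  // actual number of nonzeros is known (after the compute phase below), they
+  // are re-allocated and installed into spMatC via SetCsrPointersOp.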
+  Value zero = constantIndex(rewriter, loc, 0);
+  Value one = constantIndex(rewriter, loc, 1);
+  Value mplus1 = rewriter.create<arith::AddIOp>(loc, szm, one);
+  auto e1 = genAllocBuffer(rewriter, loc, cTp.getPosType(), mplus1, token);
+  Value rowC = e1.getResult(0);
+  token = e1.getAsyncToken();
+  auto e2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), zero, token);
+  Value colC = e2.getResult(0);
+  token = e2.getAsyncToken();
+  auto e3 = genAllocBuffer(rewriter, loc, dnCType, zero, token);
+  Value valC = e3.getResult(0);
+  token = e3.getAsyncToken();
+  Operation *spGenC =
+      genSpMat(rewriter, loc, spmatHandleTp, tokenTp, token, szm, szn, zero,
+               rowC, colC, valC, isCOO, enableRT);
+  Value spMatC = spGenC->getResult(0);
+  token = spGenC->getResult(1);
+
+  // Precompute buffer sizes for SpGEMM.
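+  // Each SpGEMM phase (work estimation and compute) is issued twice: a first
+  // call with a zero-sized buffer merely queries the required buffer size,
+  // and a second call with the allocated buffer performs the actual phase.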
+  Operation *descOp =
+      rewriter.create<gpu::SpGEMMCreateDescrOp>(loc, descTp, tokenTp, token);
+  Value desc = descOp->getResult(0);
+  token = descOp->getResult(1);
+  Operation *work1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
+      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
+      valC, gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
+  Value bufferSz1 = work1->getResult(0);
+  token = work1->getResult(1);
+  auto buf1 = genAllocBuffer(rewriter, loc, bufferSz1, token);
+  Value buffer1 = buf1.getResult(0);
+  token = buf1.getAsyncToken();
+  Operation *work2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
+      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
+      bufferSz1, buffer1,
+      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION);
+  token = work2->getResult(1);
+
+  // Compute step.
+  Operation *compute1 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
+      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType, zero,
+      valC, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
+  Value bufferSz2 = compute1->getResult(0);
+  token = compute1->getResult(1);
+  auto buf2 = genAllocBuffer(rewriter, loc, bufferSz2, token);
+  Value buffer2 = buf2.getResult(0);
+  token = buf2.getAsyncToken();
+  Operation *compute2 = rewriter.create<gpu::SpGEMMWorkEstimationOrComputeOp>(
+      loc, indexTp, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType,
+      bufferSz2, buffer2, gpu::SpGEMMWorkEstimationOrComputeKind::COMPUTE);
+  token = compute2->getResult(1);
+
+  // Get sizes.
+  Operation *sizes = rewriter.create<gpu::SpGEMMGetSizeOp>(
+      loc, indexTp, indexTp, indexTp, tokenTp, token, spMatC);
+  Value nnz = sizes->getResult(2);
+  token = sizes->getResult(3);
+  auto a2 = genAllocBuffer(rewriter, loc, cTp.getCrdType(), nnz, token);
+  colC = a2.getResult(0);
+  token = a2.getAsyncToken();
+  auto a3 = genAllocBuffer(rewriter, loc, dnCType, nnz, token);
+  valC = a3.getResult(0);
+  token = a3.getAsyncToken();
+
+  // Update C with new pointers and copy final product back into C.
+  Operation *update = rewriter.create<gpu::SetCsrPointersOp>(
+      loc, tokenTp, token, spMatC, rowC, colC, valC);
+  token = update->getResult(0);
+  Operation *copy = rewriter.create<gpu::SpGEMMCopyOp>(
+      loc, tokenTp, token, desc, gpu::TransposeMode::NON_TRANSPOSE,
+      gpu::TransposeMode::NON_TRANSPOSE, spMatA, spMatB, spMatC, dnCType);
+  token = copy->getResult(0);
+
+  // Allocate buffers on host.
+  Value rowH = genHostBuffer(rewriter, loc, cTp.getPosType(), mplus1);
+  Value colH = genHostBuffer(rewriter, loc, cTp.getCrdType(), nnz);
+  Value valH = genHostBuffer(rewriter, loc, dnCType, nnz);
+
+  // Copy data back to host and free all the resources.
+  token = rewriter.create<gpu::SpGEMMDestroyDescrOp>(loc, tokenTp, token, desc)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatA)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatB)
+              .getAsyncToken();
+  token = rewriter.create<gpu::DestroySpMatOp>(loc, tokenTp, token, spMatC)
+              .getAsyncToken();
+  token = genCopyMemRef(rewriter, loc, rowH, rowC, token);
+  token = genCopyMemRef(rewriter, loc, colH, colC, token);
+  token = genCopyMemRef(rewriter, loc, valH, valC, token);
+  tokens.push_back(token);
+  genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
+
+  // Done.
+  Value vt = rewriter.create<bufferization::ToTensorOp>(loc, valH);
+  Value rt = rewriter.create<bufferization::ToTensorOp>(loc, rowH);
+  Value ct = rewriter.create<bufferization::ToTensorOp>(loc, colH);
+  rewriter.replaceOpWithNewOp<PackOp>(op, c.getType(), vt, ValueRange{rt, ct});
+  return success();
+}
+
+// Match and rewrite 2:4 SpMM kernel.
 static LogicalResult
 rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
                 GPUDataTransferStrategy gpuDataTransferStrategy) {
@@ -748,7 +924,6 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
     castB = genHostRegisterMemref(rewriter, loc, bufB);
     castC = genHostRegisterMemref(rewriter, loc, bufC);
   }
-
   if (isZeroCopy) {
     matA = bufA;
     matB = bufB;
@@ -756,10 +931,11 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
   Value matC = genAllocCopy(rewriter, loc, bufC, tokens);
   genBlockingWait(rewriter, loc, tokens);
   tokens.clear();
+
+  // Create sparse environment and sparse matrix/dense tensor handles.
   Value szm = linalg::createOrFoldDimOp(rewriter, loc, matA, 0);
   Value szk = linalg::createOrFoldDimOp(rewriter, loc, matB, 0);
   Value szn = linalg::createOrFoldDimOp(rewriter, loc, matC, 1);
-
   Type indexTp = rewriter.getIndexType();
   Type dnTensorHandleTp = rewriter.getType<gpu::SparseDnTensorHandleType>();
   Type spMatHandleTp = rewriter.getType<gpu::SparseSpMatHandleType>();
@@ -768,7 +944,6 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
   Operation *spGenA = rewriter.create<gpu::Create2To4SpMatOp>(
       loc, spMatHandleTp, tokenTp, token, szm, szk,
       gpu::Prune2To4SpMatFlag::PRUNE_AND_CHECK, matA);
-
   Value spMatA = spGenA->getResult(0);
   token = spGenA->getResult(1);
   auto dmatB = rewriter.create<gpu::CreateDnTensorOp>(
@@ -781,7 +956,6 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
       SmallVector<Value>{szm, szn});
   Value dnC = dmatC.getResult(0);
   token = dmatC.getAsyncToken();
-
   auto dmatCType = llvm::cast<ShapedType>(matC.getType()).getElementType();
 
   // Precompute buffersize for SpMM.
@@ -791,8 +965,8 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
       loc, bufferTypes, tokenTp, token, gpu::TransposeMode::NON_TRANSPOSE,
       gpu::TransposeMode::NON_TRANSPOSE, spMatA, dnB, dnC,
       /*computeType=*/dmatCType);
-
   token = bufferComp.getAsyncToken();
+
   Value bufferSz = bufferComp.getResult(0);
   auto buf = genAllocBuffer(rewriter, loc, bufferSz, token);
   Value buffer = buf.getResult(0);
@@ -824,11 +998,9 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
   token = rewriter.create<gpu::DestroyDnTensorOp>(loc, tokenTp, token, dnC)
               .getAsyncToken();
   SmallVector<Value> newDynamicSizes;
-
   token = genDeallocMemRef(rewriter, loc, buffer, token);
   token = genDeallocMemRef(rewriter, loc, buffer2, token);
   token = genDeallocMemRef(rewriter, loc, buffer3, token);
-
   if (!isZeroCopy)
     token = genDeallocMemRef(rewriter, loc, matA, token);
   if (!isZeroCopy)
@@ -837,12 +1009,14 @@ rewrite2To4SpMM(PatternRewriter &rewriter, linalg::GenericOp op,
   token = genDeallocMemRef(rewriter, loc, matC, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
     genHostUnregisterMemref(rewriter, loc, castA);
     genHostUnregisterMemref(rewriter, loc, castB);
     genHostUnregisterMemref(rewriter, loc, castC);
   }
-  tokens.clear();
+
+  // Done.
   rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, bufC);
   return success();
 }
@@ -889,7 +1063,6 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   Value memR = genFirstPosOrCrds(rewriter, loc, c, isCOO, enableRT);
   Value memC = genSecondCrds(rewriter, loc, c, isCOO, enableRT);
   Value memV = genToValues(rewriter, loc, c);
-
   Value castB, castA, castR, castC, castV;
   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
     castB = genHostRegisterMemref(rewriter, loc, bufB);
@@ -899,7 +1072,6 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
       castC = genHostRegisterMemref(rewriter, loc, memC);
     castV = genHostRegisterMemref(rewriter, loc, memV);
   }
-
   if (isZeroCopy) {
     matA = bufA;
     matB = bufB;
@@ -930,8 +1102,8 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
                rowC, colC, valC, isCOO, enableRT);
   Value spMatC = spGenC->getResult(0);
   token = spGenC->getResult(1);
-
   auto dnCType = llvm::cast<ShapedType>(c.getType()).getElementType();
+
   // Precompute buffersize for SDDMM.
   auto bufferComp = rewriter.create<gpu::SDDMMBufferSizeOp>(
       loc, indexTp, tokenTp, token, dnA, dnB, spMatC, dnCType);
@@ -965,6 +1137,7 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
   token = genDeallocMemRef(rewriter, loc, valC, token);
   tokens.push_back(token);
   genBlockingWait(rewriter, loc, tokens);
+  tokens.clear();
   if (gpuDataTransferStrategy != GPUDataTransferStrategy::kRegularDMA) {
     genHostUnregisterMemref(rewriter, loc, castB);
     genHostUnregisterMemref(rewriter, loc, castA);
@@ -973,7 +1146,6 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
       genHostUnregisterMemref(rewriter, loc, castC);
     genHostUnregisterMemref(rewriter, loc, castV);
   }
-  tokens.clear();
 
   // Done.
   rewriter.replaceOpWithNewOp<sparse_tensor::LoadOp>(op, c);
@@ -986,7 +1158,7 @@ rewriteSDDMM(PatternRewriter &rewriter, linalg::GenericOp op, bool enableRT,
 
 /// Proof-of-concept rewriter. This rule generates a GPU implementation
 /// for each outermost forall loop generated by the sparse compiler.
-/// TODO: right works with parallelization-strategy=dense-outer-loop
+/// TODO: right now works with parallelization-strategy=dense-outer-loop
 ///       but give this its own flags in the future
 struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> {
   using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
@@ -1109,29 +1281,27 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
     AffineExpr i, j, k;
     bindDims(getContext(), i, j, k);
 
-    // TODO: more robust patterns, tranposed versions, more kernels...
-    // TODO: identify alpha and beta and pass them to the CUDA calls
+    // TODO: more robust patterns, transposed versions, more kernels,
+    //       identify alpha and beta and pass them to the CUDA calls.
 
     // Recognize a SpMV kernel.
     if (numLoops == 2 && numTensors == 3 &&
         linalg::isParallelIterator(iteratorTypes[0]) &&
         linalg::isReductionIterator(iteratorTypes[1]) &&
-        // TODO: add transposed {i, j}
         maps == infer({{i, j}, {j}, {i}}) && matchSumOfMultOfArgs(op)) {
       return rewriteSpMV(rewriter, op, enableRT, gpuDataTransferStrategy);
     }
 
-    // Recognize a SpMM kernel.
+    // Recognize a SpGEMM, 2:4-SpMM, or SpMM kernel.
     if (numLoops == 3 && numTensors == 3 &&
         linalg::isParallelIterator(iteratorTypes[0]) &&
         linalg::isParallelIterator(iteratorTypes[1]) &&
         linalg::isReductionIterator(iteratorTypes[2]) &&
-        // TODO: add transposed {i, k}, {k, j}
-        // TODO: maybe add transposed {i, j} in future
         maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
+      if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
+        return rewriteSpGEMM(rewriter, op, enableRT, gpuDataTransferStrategy);
       if (op->getAttr("DENSE24"))
         return rewrite2To4SpMM(rewriter, op, gpuDataTransferStrategy);
-
       return rewriteSpMM(rewriter, op, enableRT, gpuDataTransferStrategy);
     }
 
@@ -1140,8 +1310,6 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
         linalg::isParallelIterator(iteratorTypes[0]) &&
         linalg::isParallelIterator(iteratorTypes[1]) &&
         linalg::isReductionIterator(iteratorTypes[2]) &&
-        // TODO: add transposed {i, k}, {k, j}
-        // TODO: maybe add transposed {i, j} in future
         maps == infer({{i, k}, {k, j}, {i, j}}) &&
         matchSumReductionOfMulUnary(op)) {
       return rewriteSDDMM(rewriter, op, enableRT, gpuDataTransferStrategy);

diff  --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir
new file mode 100644
index 00000000000000..a39fdd8dc0ac6a
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-gemm-lib.mlir
@@ -0,0 +1,81 @@
+//
+// NOTE: this test requires gpu-sm80
+//
+// without RT lib:
+//
+// RUN: mlir-opt %s \
+// RUN:   --sparse-compiler="enable-runtime-library=false enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71"  \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --e main --entry-point-result=void \
+// RUN: | FileCheck %s
+
+#CSR = #sparse_tensor.encoding<{
+  lvlTypes = [ "dense", "compressed" ],
+  posWidth = 32,
+  crdWidth = 32
+}>
+
+module {
+  llvm.func @mgpuCreateSparseEnv()
+  llvm.func @mgpuDestroySparseEnv()
+
+  // Computes C = A x B with A,B,C sparse CSR.
+  func.func @matmulCSR(%A: tensor<8x8xf32, #CSR>,
+                       %B: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR> {
+    %init = bufferization.alloc_tensor() : tensor<8x8xf32, #CSR>
+    %C = linalg.matmul
+      ins(%A, %B: tensor<8x8xf32, #CSR>,
+                  tensor<8x8xf32, #CSR>)
+      outs(%init: tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR>
+    return %C: tensor<8x8xf32, #CSR>
+  }
+
+  //
+  // Main driver.
+  //
+  func.func @main() {
+    llvm.call @mgpuCreateSparseEnv(): () -> ()
+
+    %c0 = arith.constant 0 : index
+    %f0 = arith.constant 0.0 : f32
+
+    %t = arith.constant dense<[
+       [ 1.0,  0.0,  2.0,  0.0,  0.0,  0.0,  0.0,  3.0],
+       [ 0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0,  0.0],
+       [ 0.0,  0.0,  4.0,  0.0,  0.0,  0.0,  0.0,  0.0],
+       [ 0.0,  0.0,  0.0,  5.0,  0.0,  0.0,  0.0,  0.0],
+       [ 0.0,  0.0,  0.0,  0.0,  6.0,  0.0,  0.0,  0.0],
+       [ 0.0,  7.0,  8.0,  0.0,  0.0,  0.0,  0.0,  9.0],
+       [ 0.0,  0.0, 10.0,  0.0,  0.0,  0.0, 11.0, 12.0],
+       [ 0.0, 13.0, 14.0,  0.0,  0.0,  0.0, 15.0, 16.0]
+    ]> : tensor<8x8xf32>
+    %Acsr = sparse_tensor.convert %t : tensor<8x8xf32> to tensor<8x8xf32, #CSR>
+
+    %Ccsr = call @matmulCSR(%Acsr, %Acsr) : (tensor<8x8xf32, #CSR>,
+                                             tensor<8x8xf32, #CSR>) -> tensor<8x8xf32, #CSR>
+
+    //
+    // Verify computed result (expected output, with only 20 nonzeros).
+    //
+    // CHECK:    ( ( 1, 39, 52, 0, 0, 0, 45, 51 ),
+    // CHECK-SAME: ( 0, 0, 0, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 0, 16, 0, 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 0, 0, 25, 0, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 0, 0, 0, 36, 0, 0, 0 ),
+    // CHECK-SAME: ( 0, 117, 158, 0, 0, 0, 135, 144 ),
+    // CHECK-SAME: ( 0, 156, 318, 0, 0, 0, 301, 324 ),
+    // CHECK-SAME: ( 0, 208, 430, 0, 0, 0, 405, 436 ) )
+    // CHECK-NEXT: 20
+    %d = sparse_tensor.convert %Ccsr : tensor<8x8xf32, #CSR> to tensor<8x8xf32>
+    %v = vector.transfer_read %d[%c0, %c0], %f0: tensor<8x8xf32>, vector<8x8xf32>
+    vector.print %v : vector<8x8xf32>
+    %nnz = sparse_tensor.number_of_entries %Ccsr : tensor<8x8xf32, #CSR>
+    vector.print %nnz : index
+
+    llvm.call @mgpuDestroySparseEnv(): () -> ()
+    return
+  }
+}

More information about the Mlir-commits mailing list