[Mlir-commits] [mlir] ac30f48 - [mlir][sparse][gpu] fix various cusparseLt bugs

Kun Wu llvmlistbot at llvm.org
Mon Jun 12 16:48:58 PDT 2023


Author: Kun Wu
Date: 2023-06-12T23:48:49Z
New Revision: ac30f48e377b80a83953e346054ffe3288276713

URL: https://github.com/llvm/llvm-project/commit/ac30f48e377b80a83953e346054ffe3288276713
DIFF: https://github.com/llvm/llvm-project/commit/ac30f48e377b80a83953e346054ffe3288276713.diff

LOG: [mlir][sparse][gpu] fix various cusparseLt bugs

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D152489

Added: 
    

Modified: 
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 689a705350c75..8e922f54b5e27 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -263,36 +263,30 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       {llvmPointerType, llvmInt32Type, llvmPointerType, llvmPointerType,
        llvmPointerType, llvmInt32Type, llvmPointerType,
        llvmPointerType /* void *stream */}};
-  FunctionCallBuilder spMMBufferSizeCallBuilder = {
+  FunctionCallBuilder createSpMMBufferSizeCallBuilder = {
       "mgpuSpMMBufferSize",
       llvmIntPtrType,
       {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
        llvmPointerType, llvmPointerType, llvmInt32Type,
        llvmPointerType /* void *stream */}};
-  FunctionCallBuilder spMMCallBuilder = {
+  FunctionCallBuilder createSpMMCallBuilder = {
       "mgpuSpMM",
       llvmVoidType,
       {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
        llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType,
        llvmPointerType /* void *stream */}};
-  FunctionCallBuilder SDDMMBufferSizeCallBuilder = {
+  FunctionCallBuilder createSDDMMBufferSizeCallBuilder = {
       "mgpuSDDMMBufferSize",
       llvmIntPtrType,
       {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
        llvmPointerType, llvmPointerType, llvmInt32Type,
        llvmPointerType /* void *stream */}};
-  FunctionCallBuilder SDDMMCallBuilder = {
+  FunctionCallBuilder createSDDMMCallBuilder = {
       "mgpuSDDMM",
       llvmVoidType,
       {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
        llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType,
        llvmPointerType /* void *stream */}};
-  FunctionCallBuilder AssertSparseLTEnvHandleSizeCallBuilder = {
-      "mgpuAssertSparseLTEnvHandleSize", llvmVoidType, {}};
-  FunctionCallBuilder AssertSparseLTSpMatHandleSizeCallBuilder = {
-      "mgpuAssertSparseLTSpMatHandleSize", llvmVoidType, {}};
-  FunctionCallBuilder AssertSparseLTDnMatHandleSizeCallBuilder = {
-      "mgpuAssertSparseLtDnMatHandleSize", llvmVoidType, {}};
   FunctionCallBuilder createSparseLtEnvCallBuilder = {
       "mgpuCreateSparseLtEnv",
       llvmVoidType,
@@ -319,16 +313,17 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       llvmVoidType,
       {llvmPointerType, llvmPointerType, llvmIntPtrType, llvmIntPtrType,
        llvmPointerType, llvmInt32Type, llvmPointerType /* void *stream */}};
-  FunctionCallBuilder cuSparseLtSpmmBufferSizeBuilder = {
+  FunctionCallBuilder createCuSparseLtSpMMBufferSizeBuilder = {
       "mgpuCuSparseLtSpMMBufferSize",
       llvmVoidType,
-      {llvmPointerType, llvmPointerType, llvmPointerType,
+      {llvmPointerType, llvmPointerType, llvmInt32Type, llvmInt32Type,
+       llvmPointerType, llvmPointerType, llvmPointerType, llvmInt32Type,
        llvmPointerType /*void *stream*/}};
-  FunctionCallBuilder cuSparseLtSpmmBuilder = {
+  FunctionCallBuilder createCuSparseLtSpMMBuilder = {
       "mgpuCuSparseLtSpMM",
       llvmVoidType,
       {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
-       llvmInt32Type, llvmPointerType, llvmPointerType, llvmPointerType,
+       llvmPointerType, llvmPointerType, llvmPointerType,
        llvmPointerType /*void *stream*/}};
 };
 
@@ -1397,13 +1392,6 @@ static Value genConstInt32From(OpBuilder &builder, Location loc, T TValue) {
                                           static_cast<int32_t>(TValue));
 }
 
-static Value genConstInt32FromComputeMode(OpBuilder &builder, Location loc,
-                                          Type computeType) {
-  auto computeTypeInt = getCuSparseDataTypeFrom(computeType);
-  auto computeTypeConst = genConstInt32From(builder, loc, computeTypeInt);
-  return computeTypeConst;
-}
-
 LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::CreateSparseEnvOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -1416,8 +1404,7 @@ LogicalResult ConvertCreateSparseEnvOpToGpuRuntimeCallPattern::matchAndRewrite(
   // 2:4 sparsity
   Value handle;
   if (isSpMMCusparseLtOp(op.getEnv())) {
-    // Assert the size is 11024 bytes
-    AssertSparseLTEnvHandleSizeCallBuilder.create(loc, rewriter, {});
+    // CUDA runner asserts the size is 11024 bytes.
     auto handleSz = rewriter.create<LLVM::ConstantOp>(
         loc, getIndexType(), rewriter.getIndexAttr(11024));
     handle = rewriter.create<LLVM::AllocaOp>(loc, llvmInt8PointerType,
@@ -1484,7 +1471,6 @@ LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
   if (dims.size() == 2) {
     if (isSpMMCusparseLtOp(op.getDnTensor())) {
       auto envHandle = adaptor.getEnv();
-      AssertSparseLTDnMatHandleSizeCallBuilder.create(loc, rewriter, {});
       auto handleSz = rewriter.create<LLVM::ConstantOp>(
           loc, getIndexType(), rewriter.getIndexAttr(11032));
       handle = rewriter.create<LLVM::AllocaOp>(loc, llvmInt8PointerType,
@@ -1660,10 +1646,10 @@ LogicalResult ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern::matchAndRewrite(
     pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
   Type dType =
       llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
-  auto dtp = genConstInt32From(rewriter, loc, getCuSparseLtDataTypeFrom(dType));
+  auto dtp = genConstInt32From(rewriter, loc, getCuSparseDataTypeFrom(dType));
   auto envHandle = adaptor.getEnv();
 
-  AssertSparseLTSpMatHandleSizeCallBuilder.create(loc, rewriter, {});
+  // CUDA runner asserts the size is 44104 bytes.
   auto handleSz = rewriter.create<LLVM::ConstantOp>(
       loc, getIndexType(), rewriter.getIndexAttr(44104));
   Value handle = rewriter.create<LLVM::AllocaOp>(loc, llvmInt8PointerType,
@@ -1707,8 +1693,8 @@ LogicalResult ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
     return failure();
   Location loc = op.getLoc();
   auto modeA = genConstInt32From(rewriter, loc, op.getModeA());
-  auto computeType =
-      genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
+  auto computeType = genConstInt32From(
+      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
   auto stream = adaptor.getAsyncDependencies().front();
   auto bufferSize =
       spMVBufferSizeCallBuilder
@@ -1728,8 +1714,8 @@ LogicalResult ConvertSpMVOpToGpuRuntimeCallPattern::matchAndRewrite(
     return failure();
   Location loc = op.getLoc();
   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
-  auto computeType =
-      genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
+  auto computeType = genConstInt32From(
+      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1753,10 +1739,10 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
   auto stream = adaptor.getAsyncDependencies().front();
-  auto computeType =
-      genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
   Value bufferSize;
   if (is2To4Sparsity(op.getSpmatA())) {
+    auto computeType = genConstInt32From(
+        rewriter, loc, getCuSparseLtDataTypeFrom(adaptor.getComputeType()));
     auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
                                                    rewriter.getIndexAttr(3));
     bufferSize = rewriter.create<LLVM::AllocaOp>(loc, llvmInt64PointerType,
@@ -1764,13 +1750,17 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
     bufferSize =
         rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, bufferSize);
 
-    cuSparseLtSpmmBufferSizeBuilder
+    createCuSparseLtSpMMBufferSizeBuilder
         .create(loc, rewriter,
-                {bufferSize, adaptor.getEnv(), adaptor.getSpmatA(), stream})
+                {bufferSize, adaptor.getEnv(), modeA, modeB,
+                 adaptor.getSpmatA(), adaptor.getDnmatB(), adaptor.getDnmatC(),
+                 computeType, stream})
         .getResult();
     rewriter.replaceOp(op, {bufferSize, stream});
   } else {
-    bufferSize = spMMBufferSizeCallBuilder
+    auto computeType = genConstInt32From(
+        rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
+    bufferSize = createSpMMBufferSizeCallBuilder
                      .create(loc, rewriter,
                              {adaptor.getEnv(), modeA, modeB,
                               adaptor.getSpmatA(), adaptor.getDnmatB(),
@@ -1790,10 +1780,10 @@ LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
   Location loc = op.getLoc();
   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
-  auto computeType =
-      genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
+  auto computeType = genConstInt32From(
+      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
   auto stream = adaptor.getAsyncDependencies().front();
-  auto bufferSize = SDDMMBufferSizeCallBuilder
+  auto bufferSize = createSDDMMBufferSizeCallBuilder
                         .create(loc, rewriter,
                                 {adaptor.getEnv(), modeA, modeB,
                                  adaptor.getDnmatA(), adaptor.getDnmatB(),
@@ -1812,8 +1802,8 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
   Location loc = op.getLoc();
   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
-  auto computeType =
-      genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
+  auto computeType = genConstInt32From(
+      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
 
   auto stream = adaptor.getAsyncDependencies().front();
 
@@ -1826,20 +1816,19 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
         pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
       pBufs.push_back(pBuf);
     }
-    cuSparseLtSpmmBuilder.create(loc, rewriter,
-                                 {adaptor.getEnv(), adaptor.getSpmatA(),
-                                  adaptor.getDnmatB(), adaptor.getDnmatC(),
-                                  computeType, pBufs[0], pBufs[1], pBufs[2],
-                                  stream});
+    createCuSparseLtSpMMBuilder.create(
+        loc, rewriter,
+        {adaptor.getEnv(), adaptor.getSpmatA(), adaptor.getDnmatB(),
+         adaptor.getDnmatC(), pBufs[0], pBufs[1], pBufs[2], stream});
   } else {
     Value pBuf = MemRefDescriptor(adaptor.getBuffers().front())
                      .allocatedPtr(rewriter, loc);
     if (!getTypeConverter()->useOpaquePointers())
       pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
-    spMMCallBuilder.create(loc, rewriter,
-                           {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
-                            adaptor.getDnmatB(), adaptor.getDnmatC(),
-                            computeType, pBuf, stream});
+    createSpMMCallBuilder.create(
+        loc, rewriter,
+        {adaptor.getEnv(), modeA, modeB, adaptor.getSpmatA(),
+         adaptor.getDnmatB(), adaptor.getDnmatC(), computeType, pBuf, stream});
   }
   rewriter.replaceOp(op, {stream});
   return success();
@@ -1860,8 +1849,8 @@ LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
-  auto computeType =
-      genConstInt32FromComputeMode(rewriter, loc, adaptor.getComputeType());
+  auto computeType = genConstInt32From(
+      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
   auto stream = adaptor.getAsyncDependencies().front();
@@ -1869,10 +1858,10 @@ LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
   if (!getTypeConverter()->useOpaquePointers())
     pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
-  SDDMMCallBuilder.create(loc, rewriter,
-                          {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(),
-                           adaptor.getDnmatB(), adaptor.getSpmatC(),
-                           computeType, pBuf, stream});
+  createSDDMMCallBuilder.create(
+      loc, rewriter,
+      {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(), adaptor.getDnmatB(),
+       adaptor.getSpmatC(), computeType, pBuf, stream});
   rewriter.replaceOp(op, {stream});
   return success();
 }

diff  --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index c811b72a21403..47fee38e84d2c 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -468,13 +468,13 @@ mgpuSDDMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
 
 struct cusparseLtSpMatHandleAndData {
   cusparseLtMatDescriptor_t mat;
-  void *values{nullptr};
-  // TODO: the following is associated with the SpMM operator rather than the
-  // sparse matrix. Create workspace buffers and pass them to the SpMM
+  // TODO: the following three are associated with the SpMM operator rather than
+  // the sparse matrix. Create workspace buffers and pass them to the SpMM
   // execution.
   cusparseLtMatmulAlgSelection_t alg_sel;
   cusparseLtMatmulPlan_t plan;
   cusparseLtMatmulDescriptor_t matmul;
+  void *values{nullptr};
 };
 
 struct cusparseLtDnMatHandleAndData {
@@ -482,24 +482,15 @@ struct cusparseLtDnMatHandleAndData {
   void *values{nullptr};
 };
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuAssertSparseLTEnvHandleSize() {
-  assert(sizeof(cusparseLtHandle_t) == 11024);
-}
-
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuAssertSparseLtSpMatHandleSize() {
-  return assert(sizeof(cusparseLtSpMatHandleAndData) == 44104);
-}
-
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSparseLtDnMatHandleSize() {
-  return assert(sizeof(cusparseLtDnMatHandleAndData) == 11032);
-}
+static_assert(sizeof(cusparseLtHandle_t) == 11024);
+static_assert(sizeof(cusparseLtSpMatHandleAndData) == 44104);
+static_assert(sizeof(cusparseLtDnMatHandleAndData) == 11032);
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuCreateSparseLtEnv(void *h, CUstream /*stream*/) {
   // note that cuSparseLt still uses cusparseStatus_t
   CUSPARSE_REPORT_IF_ERROR(
       cusparseLtInit(reinterpret_cast<cusparseLtHandle_t *>(h)))
-  return;
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
@@ -510,15 +501,16 @@ mgpuDestroySparseLtEnv(void *h, CUstream /*stream*/) {
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuCreateCuSparseLtDnMat(void *dh, void *h, intptr_t rows, intptr_t cols,
-                          void *values, int32_t dw, CUstream /*stream*/) {
-  cusparseLtMatDescriptor_t mat;
+                          void *values, int32_t dtp, CUstream /*stream*/) {
   auto handle = reinterpret_cast<cusparseLtHandle_t *>(h);
+  // CusparseLt expects the descriptors to be zero-initialized.
+  memset(dh, 0, sizeof(cusparseLtDnMatHandleAndData));
   auto dnmat_handle = reinterpret_cast<cusparseLtDnMatHandleAndData *>(dh);
-  cudaDataType_t dtp = dataTp(dw);
+  auto dTp = static_cast<cudaDataType_t>(dtp);
   // assuming row-major when deciding lda
   CUSPARSE_REPORT_IF_ERROR(cusparseLtDenseDescriptorInit(
-      handle, &(dh->mat), rows, cols, /*lda=*/cols,
-      /*alignment=*/16, dtp, CUSPARSE_ORDER_ROW))
+      handle, &(dnmat_handle->mat), rows, cols, /*lda=*/cols,
+      /*alignment=*/16, dTp, CUSPARSE_ORDER_ROW))
   dnmat_handle->values = values;
 }
 
@@ -526,94 +518,99 @@ mgpuCreateCuSparseLtDnMat(void *dh, void *h, intptr_t rows, intptr_t cols,
 // cusparseLt
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuDestroyCuSparseLtSpMat(void *m, CUstream /*stream*/) {
-  auto matAndData = reinterpret_cast<cusparseLtSpMatHandleAndData>(m);
+  auto matAndData = reinterpret_cast<cusparseLtSpMatHandleAndData *>(m);
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(matAndData->mat)))
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuDestroyCuSparseLtDnMat(void *m, CUstream /*stream*/) {
-  auto matAndData = reinterpret_cast<cusparseLtDnMatHandleAndData>(m);
-  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(mat->mat)))
+  auto matAndData = reinterpret_cast<cusparseLtDnMatHandleAndData *>(m);
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(matAndData->mat)))
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuCusparseLtCreate2To4SpMat(void *sh, void *h, intptr_t rows, intptr_t cols,
-                              void *values, int32_t dw, CUstream /*stream*/) {
+                              void *values, int32_t dtp, CUstream /*stream*/) {
   auto spmat_handle = reinterpret_cast<cusparseLtSpMatHandleAndData *>(sh);
+  // CusparseLt expects the descriptors to be zero-initialized.
+  memset(spmat_handle, 0, sizeof(cusparseLtSpMatHandleAndData));
   spmat_handle->values = values;
   auto handle = reinterpret_cast<cusparseLtHandle_t *>(h);
-  cudaDataType_t dtp = dataTp_cusparseLt(dw);
+  auto dTp = static_cast<cudaDataType_t>(dtp);
   // assuming row-major when deciding lda
   CUSPARSE_REPORT_IF_ERROR(cusparseLtStructuredDescriptorInit(
-      handle, &(sh->mat), rows, cols, /*ld=*/cols, /*alignment=*/16, dtp,
-      CUSPARSE_ORDER_ROW, CUSPARSELT_SPARSITY_50_PERCENT))
+      handle, &(spmat_handle->mat), rows, cols, /*ld=*/cols, /*alignment=*/16,
+      dTp, CUSPARSE_ORDER_ROW, CUSPARSELT_SPARSITY_50_PERCENT))
 }
 
 // Several things are being done in this stage, algorithm selection, planning,
 // and returning workspace and compressed matrices data buffer sizes.
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
-mgpuCuSparseLtSpMMBufferSize(void *workspace_size, void *compressed_size,
-                             void *compressed_buffer_size, void *h, void *a,
+mgpuCuSparseLtSpMMBufferSize(void *bs, void *h, int32_t ma, int32_t mb, void *a,
+                             void *b, void *c, int32_t ctp,
                              CUstream /*stream*/) {
   // TODO: support more advanced settings, e.g., the input right operand is a
   // sparse matrix assuming matA is the sparse matrix
   auto handle = reinterpret_cast<cusparseLtHandle_t *>(h);
   auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
+  auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
+  auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);
+  auto workspace_size = reinterpret_cast<size_t *>(bs);
+  auto compressed_size = &(reinterpret_cast<size_t *>(bs)[1]);
+  auto compressed_buffer_size = &(reinterpret_cast<size_t *>(bs)[2]);
+  auto cTp = static_cast<cusparseComputeType>(ctp);
 
-  CHECK_CUSPARSE(cusparseLtMatmulAlgSelectionInit(
-      handle, &(matWithData.alg_sel), &matmul, CUSPARSELT_MATMUL_ALG_DEFAULT))
+  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
+  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulDescriptorInit(
+      handle, &(matA->matmul), modeA, modeB, &(matA->mat), &(matB->mat),
+      &(matC->mat), &(matC->mat), cTp))
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSelectionInit(
+      handle, &(matA->alg_sel), &(matA->matmul), CUSPARSELT_MATMUL_ALG_DEFAULT))
   int alg = 0;
-  CHECK_CUSPARSE(cusparseLtMatmulAlgSetAttribute(
-      handle, &(matWithData.alg_sel), CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg,
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulAlgSetAttribute(
+      handle, &(matA->alg_sel), CUSPARSELT_MATMUL_ALG_CONFIG_ID, &alg,
       sizeof(alg)))
-  // TODO: add transpose support
-  CHECK_CUSPARSE(cusparseLtMatmulDescriptorInit(
-      handle, &(matA.matmul), c, CUSPARSE_OPERATION_NON_TRANSPOSE, &(matA->mat),
-      &matB, &matC, &matC, compute_type))
-  CHECK_CUSPARSE(cusparseLtMatmulPlanInit(handle, &(matWithData.plan), &matmul,
-                                          &(matWithData.alg_sel)))
-
-  CHECK_CUSPARSE(
-      cusparseLtMatmulGetWorkspace(handle, &(matA.plan), workspace_size))
-  CHECK_CUSPARSE(cusparseLtSpMMACompressedSize(
-      handle, &(matA.plan), compressed_size, compressed_buffer_size))
+
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanInit(
+      handle, &(matA->plan), &(matA->matmul), &(matA->alg_sel)))
+
+  CUSPARSE_REPORT_IF_ERROR(
+      cusparseLtMatmulGetWorkspace(handle, &(matA->plan), workspace_size))
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMACompressedSize(
+      handle, &(matA->plan), compressed_size, compressed_buffer_size))
 
   // avoid zero-alloc
   *workspace_size = (*workspace_size == 0 ? 1 : *workspace_size);
   *compressed_size = (*compressed_size == 0 ? 1 : *compressed_size);
   *compressed_buffer_size =
       (*compressed_buffer_size == 0 ? 1 : *compressed_buffer_size);
-  return;
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
-mgpuCuSparseLtSpMM(void *alg_sel, void *plan, void *matmul, void *h, void *a,
-                   void *b, void *c, int32_t dw, void *buf, void *dA_compressed,
-                   void *dA_compressedBuffer, CUstream stream) {
+mgpuCuSparseLtSpMM(void *h, void *a, void *b, void *c, void *d_workspace,
+                   void *dA_compressed, void *dA_compressedBuffer,
+                   CUstream stream) {
   auto handle = reinterpret_cast<cusparseLtHandle_t *>(h);
   auto matA = reinterpret_cast<cusparseLtSpMatHandleAndData *>(a);
   auto matB = reinterpret_cast<cusparseLtDnMatHandleAndData *>(b);
   auto matC = reinterpret_cast<cusparseLtDnMatHandleAndData *>(c);
 
-  cusparseLtMatmulAlgSelection_t alg_sel;
-  cusparseLtMatmulPlan_t plan;
-  cusparseLtMatmulDescriptor_t matmul;
-
-  ALPHABETA(dw, alpha, beta)
-
-  CHECK_CUSPARSE(cusparseLtSpMMACompress(handle, &(matA->plan), &(matA->values),
-                                         dA_compressed, dA_compressedBuffer,
-                                         stream))
+  ALPHABETA(CUDA_R_32F, alpha, beta)
+  CUSPARSE_REPORT_IF_ERROR(
+      cusparseLtSpMMACompress(handle, &(matA->plan), (matA->values),
+                              dA_compressed, dA_compressedBuffer, stream))
 
   // TODO: add support to multi-stream execution
   // Perform the matrix multiplication. D = A*B+C using C==D for now
-  CHECK_CUSPARSE(
-      cusparseLtMatmul(handle, reinterpret_cast<cusparseLtMatmulPlan_t *>(plan),
-                       &alpha, dA_compressed, dB, &beta, matC->values,
-                       /*dD*/ matC->values, d_workspace, &stream, 1))
+  CUSPARSE_REPORT_IF_ERROR(
+      cusparseLtMatmul(handle, &(matA->plan), alphap, dA_compressed,
+                       matB->values, betap, matC->values,
+                       /*dD*/ matC->values, d_workspace, nullptr, 0))
 
-  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(mat->mat)))
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatDescriptorDestroy(&(matA->mat)))
   // destroy the plan associated with the sparse matrix
-  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanDestroy(&(mat->plan)))
+  CUSPARSE_REPORT_IF_ERROR(cusparseLtMatmulPlanDestroy(&(matA->plan)))
 }
 
 #endif // MLIR_ENABLE_CUDA_CUSPARSELT


        


More information about the Mlir-commits mailing list