[Mlir-commits] [mlir] 95a6c50 - [mlir][sparse][gpu] add set csr pointers, remove estimate op, fix bugs

Aart Bik llvmlistbot at llvm.org
Thu Aug 10 13:52:58 PDT 2023


Author: Aart Bik
Date: 2023-08-10T13:52:47-07:00
New Revision: 95a6c509c9ec9b87ca951de4cded9a7807058ae6

URL: https://github.com/llvm/llvm-project/commit/95a6c509c9ec9b87ca951de4cded9a7807058ae6
DIFF: https://github.com/llvm/llvm-project/commit/95a6c509c9ec9b87ca951de4cded9a7807058ae6.diff

LOG: [mlir][sparse][gpu] add set csr pointers, remove estimate op, fix bugs

Rationale:
Since we only support the default algorithm for SpGEMM, we can remove the
estimate op (for now at least). This also introduces the set csr pointers
op, and fixes a few bugs in the existing lowering for the SpGEMM breakdown.
This revision paves the way for actual recognition of SpGEMM in the sparsifier.

Reviewed By: K-Wu

Differential Revision: https://reviews.llvm.org/D157645

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
    mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
    mlir/test/Dialect/GPU/sparse-roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 8dc88663c0c8c2..498e8c37049a8d 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1561,6 +1561,7 @@ def GPU_SubgroupMmaElementwiseOp : GPU_Op<"subgroup_mma_elementwise",
 // Operation on sparse matrices, called from the host
 // (currently lowers to cuSparse for CUDA only, no ROCM lowering).
 //
+
 def GPU_CreateDnTensorOp : GPU_Op<"create_dn_tensor", [GPU_AsyncOpInterface, AttrSizedOperandSegments]> {
   let summary = "Create dense tensor operation";
   let description = [{
@@ -2234,9 +2235,9 @@ def GPU_SpGEMMWorkEstimationOrComputeOp : GPU_Op<"spgemm_work_estimation_or_comp
     Example:
 
     ```mlir
-    %bufferSz, %token = gpu.spgemm_work_estimation_or_compute async [%dep]{COMPUTE}
+    %bufferSz, %token = gpu.spgemm_work_estimation_or_compute async [%dep] {COMPUTE}
                           %desc, %spmatA{NON_TRANSPOSE}, %spmatB{NON_TRANSPOSE},
-                          %spmatC, ALG2, %spgemmDesc, %c0, %alloc: f32 into
+                          %spmatC, %spgemmDesc, %c0, %alloc: f32 into
                           memref<0xi8>
     ```
 
@@ -2283,69 +2284,11 @@ def GPU_SpGEMMWorkEstimationOrComputeOp : GPU_Op<"spgemm_work_estimation_or_comp
   }];
 }
 
-def GPU_SpGEMMEstimateMemoryOp : GPU_Op<"spgemm_estimate_memory", [GPU_AsyncOpInterface]> {
-  let summary = "SpGEMM estimate memory operation";
-  let description = [{
-    The `gpu.spgemm_estimate_memory` is used for both determining the buffer
-    size and performing the actual computation.
-
-    If the `async` keyword is present, the op is executed asynchronously (i.e.
-    it does not block until the execution has finished on the device). In
-    that case, it returns a `!gpu.async.token` in addition to the environment.
-
-    Example:
-
-    ```mlir
-    %bufferSz3, %dummy, %token = gpu.spgemm_estimate_memory async [%dep] %spmatA, %spmatB, %spmatC, ALG2, %spgemmDesc, %c0, %c0, %alloc: f32 into memref<0xi8>
-    ```
-  }];
-
-  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                       GPU_SparseSpGEMMOpHandle:$desc,
-                       GPU_TransposeModeAttr:$modeA,
-                       GPU_TransposeModeAttr:$modeB,
-                       GPU_SparseSpMatHandle:$spmatA,
-                       GPU_SparseSpMatHandle:$spmatB,
-                       GPU_SparseSpMatHandle:$spmatC,
-                       TypeAttr:$computeType,
-                       Index:$bufferSz3,
-                       AnyMemRef:$buffer3,
-                       Index:$bufferSz2);
-  let results = (outs Index:$bufferSz3New,
-                      Index:$bufferSz2New,
-                      Optional<GPU_AsyncToken>:$asyncToken);
-
-  let builders = [OpBuilder<(ins
-    "Type":$bufferSz3New,
-    "Type":$bufferSz2New,
-    "Type":$asyncToken,
-    "ValueRange":$asyncDependencies,
-    "Value":$desc,
-    "Value":$spmatA,
-    "Value":$spmatB,
-    "Value":$spmatC,
-    "Type":$computeType,
-    "Value":$bufferSz3,
-    "Value":$buffer3,
-    "Value":$bufferSz2), [{
-  auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
-  auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
-  return build($_builder, $_state, bufferSz3New, bufferSz2New, asyncToken,
-               asyncDependencies, desc, modeA, modeB, spmatA, spmatB, spmatC,
-               computeType, bufferSz3, buffer3, bufferSz2);}]>
-  ];
-
-  let assemblyFormat = [{
-    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $spmatA (`{` $modeA^ `}`)? `,` $spmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $desc `,` $bufferSz3 `,` $bufferSz2 `,` $buffer3 attr-dict `:` $computeType `into` type($buffer3)
-  }];
-}
-
 def GPU_SpGEMMCopyOp : GPU_Op<"spgemm_copy", [GPU_AsyncOpInterface]> {
   let summary = "SpGEMM copy operation";
   let description = [{
-    The `gpu.spgemm_copy` operation copies a sparse matrix, e.g., the result of
-    the SpGEMM computation.
+    The `gpu.spgemm_copy` operation copies the sparse matrix result of
+    a SpGEMM computation.
 
     If the `async` keyword is present, the op is executed asynchronously (i.e.
     it does not block until the execution has finished on the device). In
@@ -2354,7 +2297,7 @@ def GPU_SpGEMMCopyOp : GPU_Op<"spgemm_copy", [GPU_AsyncOpInterface]> {
     Example:
 
     ```mlir
-    gpu.spgemm_copy %spmatA, %spmatB, %spmatC, ALG2, %spgemmDesc: f32
+    gpu.spgemm_copy %spmatA, %spmatB, %spmatC, %spgemmDesc: f32
     ```
 
     The matrix arguments can also be associated with one of the following
@@ -2422,4 +2365,37 @@ def GPU_SpGEMMGetSizeOp : GPU_Op<"spgemm_get_size", [GPU_AsyncOpInterface]> {
   }];
 }
 
+def GPU_SetCsrPointersOp : GPU_Op<"set_csr_pointers", [GPU_AsyncOpInterface]> {
+  let summary = "SpGEMM set CSR pointers operation";
+  let description = [{
+    The `gpu.set_csr_pointers` operation assigns the given positions,
+    coordinates, and values buffers that reside on the device directly to
+    the given sparse matrix descriptor in CSR format.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a `!gpu.async.token`.
+
+    Example:
+
+    ```mlir
+    %token = gpu.set_csr_pointers async [%dep] %spmat, %positions, %coordinates, %values
+          : memref<?xindex>, memref<?xindex>, memref<?xf32>
+    ```
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                       Arg<GPU_SparseSpMatHandle>:$spmat,
+                       AnyMemRef:$positions,
+                       AnyMemRef:$coordinates,
+		       AnyMemRef:$values);
+  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+      $spmat `,` $positions `,` $coordinates `,` $values attr-dict
+        `:` type($positions) `,` type($coordinates) `,` type($values)
+  }];
+}
+
 #endif // GPU_OPS

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 37efadd1be5625..88dd32a2146eb9 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -302,15 +302,6 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
        llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
        llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
        llvmPointerType /*void *stream*/}};
-  FunctionCallBuilder createSpGEMMEstimateMemoryBuilder = {
-      "mgpuSpGEMMEstimateMemory",
-      llvmVoidType,
-      {llvmPointerType /*nbs3*/, llvmPointerType /*nbs2*/,
-       llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
-       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
-       llvmInt32Type /*ctp*/, llvmFloat32Type /*chunk_fraction*/,
-       llvmIntPtrType /*bs3*/, llvmPointerType /*buf3*/, llvmIntPtrType /*bs2*/,
-       llvmPointerType /*void *stream*/}};
   FunctionCallBuilder createSpGEMMComputeBuilder = {
       "mgpuSpGEMMCompute",
       llvmIntPtrType,
@@ -329,7 +320,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       llvmPointerType,
       {llvmPointerType /*void *stream*/}};
   FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
-      "mgpuSpGEMMDestoryDescr",
+      "mgpuSpGEMMDestroyDescr",
       llvmVoidType,
       {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
   FunctionCallBuilder createSpGEMMGetSizeBuilder = {
@@ -337,6 +328,12 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       llvmVoidType,
       {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
        llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
+  FunctionCallBuilder createSetCsrPointersBuilder = {
+      "mgpuSetCsrPointers",
+      llvmVoidType,
+      {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
+       llvmPointerType /*crd*/, llvmPointerType /*val*/,
+       llvmPointerType /*void *stream*/}};
 };
 
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -559,9 +556,9 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SDDMMOp)
 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
-DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMEstimateMemoryOp)
 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
 DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMGetSizeOp)
+DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
 
 } // namespace
 
@@ -1710,13 +1707,13 @@ LogicalResult
 ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
-
   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
   auto stream = adaptor.getAsyncDependencies().front();
-  createSpGEMMCopyBuilder.create(loc, rewriter, {adaptor.getDesc(), stream});
+  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
+                                         {adaptor.getDesc(), stream});
   rewriter.replaceOp(op, {stream});
   return success();
 }
@@ -1764,55 +1761,6 @@ ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
-LogicalResult
-ConvertSpGEMMEstimateMemoryOpToGpuRuntimeCallPattern::matchAndRewrite(
-    gpu::SpGEMMEstimateMemoryOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
-      failed(isAsyncWithOneDependency(rewriter, op)))
-    return failure();
-  Location loc = op.getLoc();
-  auto computeType = genConstInt32From(
-      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
-  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
-  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
-  auto stream = adaptor.getAsyncDependencies().front();
-  // TODO: support other chunk fraction
-  Value chunkFraction = genConstFloat32From(rewriter, loc, 1.0);
-  Value pBuf3 =
-      MemRefDescriptor(adaptor.getBuffer3()).allocatedPtr(rewriter, loc);
-  if (!getTypeConverter()->useOpaquePointers())
-    pBuf3 = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf3);
-
-  auto two = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
-                                               rewriter.getIndexAttr(2));
-  auto bufferSize = rewriter.create<LLVM::AllocaOp>(
-      loc, llvmInt64PointerType, llvmInt64Type, two, /*alignment=*/16);
-
-  auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
-      loc, llvmInt64PointerType, llvmInt64PointerType, bufferSize,
-      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
-                                                   rewriter.getIndexAttr(0))});
-  auto bufferSizePtr3 = rewriter.create<LLVM::GEPOp>(
-      loc, llvmInt64PointerType, llvmInt64PointerType, bufferSize,
-      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
-                                                   rewriter.getIndexAttr(1))});
-
-  createSpGEMMEstimateMemoryBuilder.create(
-      loc, rewriter,
-      {bufferSizePtr3, bufferSizePtr2, adaptor.getDesc(), modeA, modeB,
-       adaptor.getSpmatA(), adaptor.getSpmatB(), adaptor.getSpmatC(),
-       computeType, chunkFraction, adaptor.getBufferSz3(), pBuf3,
-       adaptor.getBufferSz2(), stream});
-  auto bufferSize2 =
-      rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
-  auto bufferSize3 =
-      rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr3);
-
-  rewriter.replaceOp(op, {bufferSize3, bufferSize2, stream});
-  return success();
-}
-
 LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -1869,6 +1817,31 @@ LogicalResult ConvertSpGEMMGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
+LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::SetCsrPointersOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
+      failed(isAsyncWithOneDependency(rewriter, op)))
+    return failure();
+  Location loc = op.getLoc();
+  auto stream = adaptor.getAsyncDependencies().front();
+  Value pPos =
+      MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
+  Value pCrd =
+      MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
+  Value pVal =
+      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
+  if (!getTypeConverter()->useOpaquePointers()) {
+    pPos = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pPos);
+    pCrd = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pCrd);
+    pVal = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVal);
+  }
+  createSetCsrPointersBuilder.create(
+      loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
+  rewriter.replaceOp(op, {stream});
+  return success();
+}
+
 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns,
                                                StringRef gpuBinaryAnnotation,
@@ -1904,9 +1877,9 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
                ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
                ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
-               ConvertSpGEMMEstimateMemoryOpToGpuRuntimeCallPattern,
                ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
-               ConvertSpGEMMGetSizeOpToGpuRuntimeCallPattern>(converter);
+               ConvertSpGEMMGetSizeOpToGpuRuntimeCallPattern,
+               ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
   patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
       converter, gpuBinaryAnnotation, kernelBarePtrCallConv);
   patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());

diff  --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 23c5e40e438189..fd338a14c504ef 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -615,45 +615,12 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpGEMMWorkEstimation(
   auto cTp = static_cast<cudaDataType_t>(ctp);
   ALPHABETA(cTp, alpha, beta)
   size_t newBufferSize = bs;
-
   CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_workEstimation(
       cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
       CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize, buf))
   return newBufferSize == 0 ? 1 : newBufferSize; // avoid zero-alloc
 }
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
-mgpuSpGEMMEstimateMemory(void *nbs3, void *nbs2, void *s, int32_t ma,
-                         int32_t mb, void *a, void *b, void *c, int32_t ctp,
-                         float chunk_fraction, intptr_t bs3, void *buf3,
-                         intptr_t bs2, CUstream /*stream*/) {
-  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
-  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
-  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
-  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
-  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
-  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
-  auto cTp = static_cast<cudaDataType_t>(ctp);
-  ALPHABETA(cTp, alpha, beta)
-  size_t *newBufferSize2 = reinterpret_cast<size_t *>(nbs2);
-  size_t *newBufferSize3 = reinterpret_cast<size_t *>(nbs3);
-  *newBufferSize2 = bs2;
-  *newBufferSize3 = bs3;
-
-  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_estimateMemory(
-      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
-      CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, chunk_fraction, newBufferSize3, buf3,
-      newBufferSize2))
-  // avoid zero-alloc
-  if (*newBufferSize2 == 0) {
-    *newBufferSize2 = 1;
-  }
-  if (*newBufferSize3 == 0) {
-    *newBufferSize3 = 1;
-  }
-  return;
-}
-
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
 mgpuSpGEMMCompute(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
                   int32_t ctp, intptr_t bsz2, void *buf2, CUstream /*stream*/) {
@@ -683,7 +650,6 @@ mgpuSpGEMMCopy(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
   cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
   auto cTp = static_cast<cudaDataType_t>(ctp);
   ALPHABETA(cTp, alpha, beta)
-
   CUSPARSE_REPORT_IF_ERROR(
       cusparseSpGEMM_copy(cusparse_env, modeA, modeB, alphap, matA, matB, betap,
                           matC, cTp, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc))
@@ -712,6 +678,12 @@ mgpuSpGEMMGetSize(void *m, void *r, void *c, void *n, CUstream /*stream*/) {
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMatGetSize(matDescr, rows, cols, nnz));
 }
 
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuSetCsrPointers(void *m, void *p, void *c, void *v, CUstream /*stream*/) {
+  cusparseSpMatDescr_t matDescr = reinterpret_cast<cusparseSpMatDescr_t>(m);
+  CUSPARSE_REPORT_IF_ERROR(cusparseCsrSetPointers(matDescr, p, c, v));
+}
+
 #ifdef MLIR_ENABLE_CUDA_CUSPARSELT
 
 ///

diff  --git a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
index 40489295143862..45b991bd4f8896 100644
--- a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
@@ -93,9 +93,6 @@ module attributes {gpu.container_module} {
   // CHECK: llvm.call @mgpuSpGEMMWorkEstimation
   // CHECK: llvm.call @mgpuMemAlloc
   // CHECK: llvm.call @mgpuSpGEMMWorkEstimation
-  // CHECK: llvm.call @mgpuSpGEMMEstimateMemory
-  // CHECK: llvm.call @mgpuMemAlloc
-  // CHECK: llvm.call @mgpuSpGEMMEstimateMemory
   // CHECK: llvm.call @mgpuMemAlloc
   // CHECK: llvm.call @mgpuSpGEMMCompute
   // CHECK: llvm.call @mgpuMemAlloc
@@ -130,19 +127,10 @@ module attributes {gpu.container_module} {
                               [%token8]{WORK_ESTIMATION} %spmatA, %spmatB,
                               %spmatC, %spgemmDesc, %bufferSz1,
                               %buf1: f32 into memref<?xi8>
-    %bufferSz3, %dummy, %token10 = gpu.spgemm_estimate_memory async [%token9]
-                                     %spmatA, %spmatB, %spmatC,
-                                     %spgemmDesc, %c0, %c0,
-                                     %alloc: f32 into memref<0xi8>
-    %buf3, %token11 = gpu.alloc async [%token10] (%bufferSz3) : memref<?xi8>
-    %bufferSz3_2, %bufferSz2, %token12 = gpu.spgemm_estimate_memory async
-                                          [%token11] %spmatA, %spmatB, %spmatC,
-                                          %spgemmDesc, %bufferSz3, %c0,
-                                          %buf3: f32 into memref<?xi8>
-    %buf2, %token13 = gpu.alloc async [%token12] (%bufferSz2) : memref<?xi8>
+    %buf2, %token13 = gpu.alloc async [%token9] (%bufferSz1_1) : memref<?xi8>
     %bufferSz2_2, %token14 = gpu.spgemm_work_estimation_or_compute async
                                [%token13]{COMPUTE} %spmatA, %spmatB, %spmatC,
-                               %spgemmDesc, %bufferSz2,
+                               %spgemmDesc, %bufferSz1_1,
                                %buf2: f32 into memref<?xi8>
     %rows, %cols, %nnz, %token15 = gpu.spgemm_get_size async [%token14] %spmatC
     %mem_columns, %token16 = gpu.alloc async [%token15] (%cols) : memref<?xi32>

diff  --git a/mlir/test/Dialect/GPU/sparse-roundtrip.mlir b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
index b39bc101af17f9..171a1ad24898ff 100644
--- a/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
+++ b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
@@ -55,72 +55,45 @@ module attributes {gpu.container_module} {
   }
 
   // CHECK-LABEL:     func @spgemm
-  // CHECK:      %{{.*}} = gpu.wait async
+  // CHECK:           %{{.*}} = gpu.wait async
   // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xindex>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf64>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf32>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf32>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf32>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf32>
   // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_create_descr async [%{{.*}}]
-  // CHECK:           %{{.*}} = memref.alloc() : memref<0xi8>
-  // CHECK:           %{{.*}} = arith.constant 0 : index
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_estimate_memory async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_estimate_memory async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{ WORK_ESTIMATION} %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{ COMPUTE} %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
   // CHECK:           %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_get_size async [%{{.*}}] %{{.*}}
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi32>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf32>
-  // CHECK:           gpu.wait [%{{.*}}]
-  // CHECK:           gpu.spgemm_copy  %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32
-  // CHECK:           gpu.destroy_sp_mat  %{{.*}}
-  // CHECK:           gpu.destroy_sp_mat  %{{.*}}
-  // CHECK:           gpu.destroy_sp_mat  %{{.*}}
+  // CHECK:           %{{.*}} = gpu.set_csr_pointers async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf32>
+  // CHECK:           %{{.*}} = gpu.spgemm_copy async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32
+  // CHECK:           %{{.*}} = gpu.spgemm_destroy_descr async [%{{.*}}] %{{.*}}
+  // CHECK:           gpu.destroy_sp_mat %{{.*}}
+  // CHECK:           gpu.destroy_sp_mat %{{.*}}
+  // CHECK:           gpu.destroy_sp_mat %{{.*}}
   // CHECK:           return
   func.func @spgemm(%arg0: index) {
     %token0 = gpu.wait async
     %mem1, %token1 = gpu.alloc async [%token0] (%arg0) : memref<?xindex>
-    %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf64>
-    %spmatA, %token3 = gpu.create_csr async [%token2] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
-    %spmatB, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
-    %spmatC, %token5 = gpu.create_csr async [%token4] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+    %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf32>
+    %spmatA, %token3 = gpu.create_csr async [%token2] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf32>
+    %spmatB, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf32>
+    %spmatC, %token5 = gpu.create_csr async [%token4] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf32>
     %spgemmDesc, %token6 = gpu.spgemm_create_descr async [%token5]
-    // Used as nullptr
-    %alloc = memref.alloc() : memref<0xi8>
+    %alloc = memref.alloc() : memref<0xi8>  // nullptr
     %c0 = arith.constant 0 : index
     %bufferSz1, %token7 = gpu.spgemm_work_estimation_or_compute async
                             [%token6]{WORK_ESTIMATION}
-                            %spmatA{NON_TRANSPOSE}, %spmatB{NON_TRANSPOSE},
-                            %spmatC, %spgemmDesc, %c0,
-                            %alloc: f32 into memref<0xi8>
-    %buf1, %token8 = gpu.alloc async [%token7] (%bufferSz1) : memref<?xi8>
-    %bufferSz1_1, %token9 = gpu.spgemm_work_estimation_or_compute async
-                              [%token8]{WORK_ESTIMATION} %spmatA, %spmatB,
-                              %spmatC, %spgemmDesc, %bufferSz1,
-                              %buf1: f32 into memref<?xi8>
-    %bufferSz3, %dummy, %token10 = gpu.spgemm_estimate_memory async [%token9]
-                                     %spmatA, %spmatB, %spmatC,
-                                     %spgemmDesc, %c0, %c0,
-                                     %alloc: f32 into memref<0xi8>
-    %buf3, %token11 = gpu.alloc async [%token10] (%bufferSz3) : memref<?xi8>
-    %bufferSz3_2, %bufferSz2, %token12 = gpu.spgemm_estimate_memory async
-                                          [%token11] %spmatA, %spmatB, %spmatC,
-                                          %spgemmDesc, %bufferSz3, %c0,
-                                          %buf3: f32 into memref<?xi8>
-    %buf2, %token13 = gpu.alloc async [%token12] (%bufferSz2) : memref<?xi8>
-    %bufferSz2_2, %token14 = gpu.spgemm_work_estimation_or_compute async
-                               [%token13]{COMPUTE} %spmatA, %spmatB, %spmatC,
-                               %spgemmDesc, %bufferSz2,
-                               %buf2: f32 into memref<?xi8>
-    %rows, %cols, %nnz, %token15 = gpu.spgemm_get_size async [%token14] %spmatC
-    %mem_columns, %token16 = gpu.alloc async [%token15] (%cols) : memref<?xi32>
-    %mem_values, %token17 = gpu.alloc async [%token16] (%nnz) : memref<?xf32>
-    gpu.wait [%token17]
-    gpu.spgemm_copy %spmatA, %spmatB, %spmatC, %spgemmDesc: f32
+                            %spmatA, %spmatB, %spmatC,
+			    %spgemmDesc, %c0, %alloc: f32 into memref<0xi8>
+    %bufferSz2, %token8 = gpu.spgemm_work_estimation_or_compute async
+                               [%token7]{COMPUTE}
+			       %spmatA, %spmatB, %spmatC,
+                               %spgemmDesc, %c0, %alloc: f32 into memref<0xi8>
+    %rows, %cols, %nnz, %token9 = gpu.spgemm_get_size async [%token8] %spmatC
+    %token10 = gpu.set_csr_pointers async [%token8] %spmatC, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf32>
+    %token11 = gpu.spgemm_copy async [%token10] %spmatA, %spmatB, %spmatC, %spgemmDesc: f32
+    %token12 = gpu.spgemm_destroy_descr async [%token11] %spgemmDesc
     gpu.destroy_sp_mat %spmatA
     gpu.destroy_sp_mat %spmatB
     gpu.destroy_sp_mat %spmatC


        


More information about the Mlir-commits mailing list