[Mlir-commits] [mlir] 95a6c50 - [mlir][sparse][gpu] add set csr pointers, remove estimate op, fix bugs

Aart Bik llvmlistbot at llvm.org
Thu Aug 10 13:52:58 PDT 2023

Author: Aart Bik
Date: 2023-08-10T13:52:47-07:00
New Revision: 95a6c509c9ec9b87ca951de4cded9a7807058ae6

URL: https://github.com/llvm/llvm-project/commit/95a6c509c9ec9b87ca951de4cded9a7807058ae6
DIFF: https://github.com/llvm/llvm-project/commit/95a6c509c9ec9b87ca951de4cded9a7807058ae6.diff

LOG: [mlir][sparse][gpu] add set csr pointers, remove estimate op, fix bugs

Since we only support the default algorithm for SpGEMM, we can remove the
estimate op (for now at least). This also introduces the set_csr_pointers
op, and fixes a few bugs in the existing lowering for the SpGEMM breakdown
(e.g., the destroy-descriptor lowering invoked the copy runtime function,
and the runtime symbol was misspelled as mgpuSpGEMMDestoryDescr).
This revision paves the way for actual recognition of SpGEMM in the sparsifier.

Reviewed By: K-Wu

Differential Revision: https://reviews.llvm.org/D157645




diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 8dc88663c0c8c2..498e8c37049a8d 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1561,6 +1561,7 @@ def GPU_SubgroupMmaElementwiseOp : GPU_Op<"subgroup_mma_elementwise",
 // Operation on sparse matrices, called from the host
 // (currently lowers to cuSparse for CUDA only, no ROCM lowering).
 def GPU_CreateDnTensorOp : GPU_Op<"create_dn_tensor", [GPU_AsyncOpInterface, AttrSizedOperandSegments]> {
   let summary = "Create dense tensor operation";
   let description = [{
@@ -2234,9 +2235,9 @@ def GPU_SpGEMMWorkEstimationOrComputeOp : GPU_Op<"spgemm_work_estimation_or_comp
-    %bufferSz, %token = gpu.spgemm_work_estimation_or_compute async [%dep]{COMPUTE}
+    %bufferSz, %token = gpu.spgemm_work_estimation_or_compute async [%dep] {COMPUTE}
                           %desc, %spmatA{NON_TRANSPOSE}, %spmatB{NON_TRANSPOSE},
-                          %spmatC, ALG2, %spgemmDesc, %c0, %alloc: f32 into
+                          %spmatC, %spgemmDesc, %c0, %alloc: f32 into
@@ -2283,69 +2284,11 @@ def GPU_SpGEMMWorkEstimationOrComputeOp : GPU_Op<"spgemm_work_estimation_or_comp
-def GPU_SpGEMMEstimateMemoryOp : GPU_Op<"spgemm_estimate_memory", [GPU_AsyncOpInterface]> {
-  let summary = "SpGEMM estimate memory operation";
-  let description = [{
-    The `gpu.spgemm_estimate_memory` is used for both determining the buffer
-    size and performing the actual computation.
-    If the `async` keyword is present, the op is executed asynchronously (i.e.
-    it does not block until the execution has finished on the device). In
-    that case, it returns a `!gpu.async.token` in addition to the environment.
-    Example:
-    ```mlir
-    %bufferSz3, %dummy, %token = gpu.spgemm_estimate_memory async [%dep] %spmatA, %spmatB, %spmatC, ALG2, %spgemmDesc, %c0, %c0, %alloc: f32 into memref<0xi8>
-    ```
-  }];
-  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
-                       GPU_SparseSpGEMMOpHandle:$desc,
-                       GPU_TransposeModeAttr:$modeA,
-                       GPU_TransposeModeAttr:$modeB,
-                       GPU_SparseSpMatHandle:$spmatA,
-                       GPU_SparseSpMatHandle:$spmatB,
-                       GPU_SparseSpMatHandle:$spmatC,
-                       TypeAttr:$computeType,
-                       Index:$bufferSz3,
-                       AnyMemRef:$buffer3,
-                       Index:$bufferSz2);
-  let results = (outs Index:$bufferSz3New,
-                      Index:$bufferSz2New,
-                      Optional<GPU_AsyncToken>:$asyncToken);
-  let builders = [OpBuilder<(ins
-    "Type":$bufferSz3New,
-    "Type":$bufferSz2New,
-    "Type":$asyncToken,
-    "ValueRange":$asyncDependencies,
-    "Value":$desc,
-    "Value":$spmatA,
-    "Value":$spmatB,
-    "Value":$spmatC,
-    "Type":$computeType,
-    "Value":$bufferSz3,
-    "Value":$buffer3,
-    "Value":$bufferSz2), [{
-  auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
-  auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
-  return build($_builder, $_state, bufferSz3New, bufferSz2New, asyncToken,
-               asyncDependencies, desc, modeA, modeB, spmatA, spmatB, spmatC,
-               computeType, bufferSz3, buffer3, bufferSz2);}]>
-  ];
-  let assemblyFormat = [{
-    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $spmatA (`{` $modeA^ `}`)? `,` $spmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $desc `,` $bufferSz3 `,` $bufferSz2 `,` $buffer3 attr-dict `:` $computeType `into` type($buffer3)
-  }];
 def GPU_SpGEMMCopyOp : GPU_Op<"spgemm_copy", [GPU_AsyncOpInterface]> {
   let summary = "SpGEMM copy operation";
   let description = [{
-    The `gpu.spgemm_copy` operation copies a sparse matrix, e.g., the result of
-    the SpGEMM computation.
+    The `gpu.spgemm_copy` operation copies the sparse matrix result of
+    a SpGEMM computation.
     If the `async` keyword is present, the op is executed asynchronously (i.e.
     it does not block until the execution has finished on the device). In
@@ -2354,7 +2297,7 @@ def GPU_SpGEMMCopyOp : GPU_Op<"spgemm_copy", [GPU_AsyncOpInterface]> {
-    gpu.spgemm_copy %spmatA, %spmatB, %spmatC, ALG2, %spgemmDesc: f32
+    gpu.spgemm_copy %spmatA, %spmatB, %spmatC, %spgemmDesc: f32
     The matrix arguments can also be associated with one of the following
@@ -2422,4 +2365,37 @@ def GPU_SpGEMMGetSizeOp : GPU_Op<"spgemm_get_size", [GPU_AsyncOpInterface]> {
+def GPU_SetCsrPointersOp : GPU_Op<"set_csr_pointers", [GPU_AsyncOpInterface]> {
+  let summary = "SpGEMM get size operation";
+  let description = [{
+    The `gpu.set_csr_pointers` assigns the given positions, coordinates,
+    and values buffer that reside on the device directly to the given sparse
+    matrix descriptor in csr format.
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a `!gpu.async.token` in addition to the environment.
+    Example:
+    ```mlir
+    %token = gpu.set_csr_pointers async [%dep] %positions, %coordinates, %values
+          : memref<?xf32>, memref<?xindex>, memref<?xindex>
+    ```
+  }];
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                       Arg<GPU_SparseSpMatHandle>:$spmat,
+                       AnyMemRef:$positions,
+                       AnyMemRef:$coordinates,
+		       AnyMemRef:$values);
+  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+      $spmat `,` $positions `,` $coordinates `,` $values attr-dict
+        `:` type($positions) `,` type($coordinates) `,` type($values)
+  }];
 #endif // GPU_OPS

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 37efadd1be5625..88dd32a2146eb9 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -302,15 +302,6 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
        llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
        llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
        llvmPointerType /*void *stream*/}};
-  FunctionCallBuilder createSpGEMMEstimateMemoryBuilder = {
-      "mgpuSpGEMMEstimateMemory",
-      llvmVoidType,
-      {llvmPointerType /*nbs3*/, llvmPointerType /*nbs2*/,
-       llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
-       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
-       llvmInt32Type /*ctp*/, llvmFloat32Type /*chunk_fraction*/,
-       llvmIntPtrType /*bs3*/, llvmPointerType /*buf3*/, llvmIntPtrType /*bs2*/,
-       llvmPointerType /*void *stream*/}};
   FunctionCallBuilder createSpGEMMComputeBuilder = {
@@ -329,7 +320,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       {llvmPointerType /*void *stream*/}};
   FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
-      "mgpuSpGEMMDestoryDescr",
+      "mgpuSpGEMMDestroyDescr",
       {llvmPointerType /*s*/, llvmPointerType /*void *stream*/}};
   FunctionCallBuilder createSpGEMMGetSizeBuilder = {
@@ -337,6 +328,12 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
        llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
+  FunctionCallBuilder createSetCsrPointersBuilder = {
+      "mgpuSetCsrPointers",
+      llvmVoidType,
+      {llvmPointerType /*spmat*/, llvmPointerType /*pos*/,
+       llvmPointerType /*crd*/, llvmPointerType /*val*/,
+       llvmPointerType /*void *stream*/}};
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
 } // namespace
@@ -1710,13 +1707,13 @@ LogicalResult
     gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
   auto stream = adaptor.getAsyncDependencies().front();
-  createSpGEMMCopyBuilder.create(loc, rewriter, {adaptor.getDesc(), stream});
+  createSpGEMMDestroyDescrBuilder.create(loc, rewriter,
+                                         {adaptor.getDesc(), stream});
   rewriter.replaceOp(op, {stream});
   return success();
@@ -1764,55 +1761,6 @@ ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
-    gpu::SpGEMMEstimateMemoryOp op, OpAdaptor adaptor,
-    ConversionPatternRewriter &rewriter) const {
-  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
-      failed(isAsyncWithOneDependency(rewriter, op)))
-    return failure();
-  Location loc = op.getLoc();
-  auto computeType = genConstInt32From(
-      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
-  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
-  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
-  auto stream = adaptor.getAsyncDependencies().front();
-  // TODO: support other chunk fraction
-  Value chunkFraction = genConstFloat32From(rewriter, loc, 1.0);
-  Value pBuf3 =
-      MemRefDescriptor(adaptor.getBuffer3()).allocatedPtr(rewriter, loc);
-  if (!getTypeConverter()->useOpaquePointers())
-    pBuf3 = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf3);
-  auto two = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
-                                               rewriter.getIndexAttr(2));
-  auto bufferSize = rewriter.create<LLVM::AllocaOp>(
-      loc, llvmInt64PointerType, llvmInt64Type, two, /*alignment=*/16);
-  auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
-      loc, llvmInt64PointerType, llvmInt64PointerType, bufferSize,
-      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
-                                                   rewriter.getIndexAttr(0))});
-  auto bufferSizePtr3 = rewriter.create<LLVM::GEPOp>(
-      loc, llvmInt64PointerType, llvmInt64PointerType, bufferSize,
-      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
-                                                   rewriter.getIndexAttr(1))});
-  createSpGEMMEstimateMemoryBuilder.create(
-      loc, rewriter,
-      {bufferSizePtr3, bufferSizePtr2, adaptor.getDesc(), modeA, modeB,
-       adaptor.getSpmatA(), adaptor.getSpmatB(), adaptor.getSpmatC(),
-       computeType, chunkFraction, adaptor.getBufferSz3(), pBuf3,
-       adaptor.getBufferSz2(), stream});
-  auto bufferSize2 =
-      rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
-  auto bufferSize3 =
-      rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr3);
-  rewriter.replaceOp(op, {bufferSize3, bufferSize2, stream});
-  return success();
 LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -1869,6 +1817,31 @@ LogicalResult ConvertSpGEMMGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
+LogicalResult ConvertSetCsrPointersOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::SetCsrPointersOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
+      failed(isAsyncWithOneDependency(rewriter, op)))
+    return failure();
+  Location loc = op.getLoc();
+  auto stream = adaptor.getAsyncDependencies().front();
+  Value pPos =
+      MemRefDescriptor(adaptor.getPositions()).allocatedPtr(rewriter, loc);
+  Value pCrd =
+      MemRefDescriptor(adaptor.getCoordinates()).allocatedPtr(rewriter, loc);
+  Value pVal =
+      MemRefDescriptor(adaptor.getValues()).allocatedPtr(rewriter, loc);
+  if (!getTypeConverter()->useOpaquePointers()) {
+    pPos = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pPos);
+    pCrd = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pCrd);
+    pVal = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVal);
+  }
+  createSetCsrPointersBuilder.create(
+      loc, rewriter, {adaptor.getSpmat(), pPos, pCrd, pVal, stream});
+  rewriter.replaceOp(op, {stream});
+  return success();
 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns,
                                                StringRef gpuBinaryAnnotation,
@@ -1904,9 +1877,9 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
-               ConvertSpGEMMEstimateMemoryOpToGpuRuntimeCallPattern,
-               ConvertSpGEMMGetSizeOpToGpuRuntimeCallPattern>(converter);
+               ConvertSpGEMMGetSizeOpToGpuRuntimeCallPattern,
+               ConvertSetCsrPointersOpToGpuRuntimeCallPattern>(converter);
       converter, gpuBinaryAnnotation, kernelBarePtrCallConv);

diff  --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 23c5e40e438189..fd338a14c504ef 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -615,45 +615,12 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpGEMMWorkEstimation(
   auto cTp = static_cast<cudaDataType_t>(ctp);
   ALPHABETA(cTp, alpha, beta)
   size_t newBufferSize = bs;
       cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
       CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize, buf))
   return newBufferSize == 0 ? 1 : newBufferSize; // avoid zero-alloc
-mgpuSpGEMMEstimateMemory(void *nbs3, void *nbs2, void *s, int32_t ma,
-                         int32_t mb, void *a, void *b, void *c, int32_t ctp,
-                         float chunk_fraction, intptr_t bs3, void *buf3,
-                         intptr_t bs2, CUstream /*stream*/) {
-  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
-  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
-  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
-  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
-  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
-  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
-  auto cTp = static_cast<cudaDataType_t>(ctp);
-  ALPHABETA(cTp, alpha, beta)
-  size_t *newBufferSize2 = reinterpret_cast<size_t *>(nbs2);
-  size_t *newBufferSize3 = reinterpret_cast<size_t *>(nbs3);
-  *newBufferSize2 = bs2;
-  *newBufferSize3 = bs3;
-  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_estimateMemory(
-      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
-      CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, chunk_fraction, newBufferSize3, buf3,
-      newBufferSize2))
-  // avoid zero-alloc
-  if (*newBufferSize2 == 0) {
-    *newBufferSize2 = 1;
-  }
-  if (*newBufferSize3 == 0) {
-    *newBufferSize3 = 1;
-  }
-  return;
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
 mgpuSpGEMMCompute(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
                   int32_t ctp, intptr_t bsz2, void *buf2, CUstream /*stream*/) {
@@ -683,7 +650,6 @@ mgpuSpGEMMCopy(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
   cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
   auto cTp = static_cast<cudaDataType_t>(ctp);
   ALPHABETA(cTp, alpha, beta)
       cusparseSpGEMM_copy(cusparse_env, modeA, modeB, alphap, matA, matB, betap,
                           matC, cTp, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc))
@@ -712,6 +678,12 @@ mgpuSpGEMMGetSize(void *m, void *r, void *c, void *n, CUstream /*stream*/) {
   CUSPARSE_REPORT_IF_ERROR(cusparseSpMatGetSize(matDescr, rows, cols, nnz));
+mgpuSetCsrPointers(void *m, void *p, void *c, void *v, CUstream /*stream*/) {
+  cusparseSpMatDescr_t matDescr = reinterpret_cast<cusparseSpMatDescr_t>(m);
+  CUSPARSE_REPORT_IF_ERROR(cusparseCsrSetPointers(matDescr, p, c, v));

diff  --git a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
index 40489295143862..45b991bd4f8896 100644
--- a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
@@ -93,9 +93,6 @@ module attributes {gpu.container_module} {
   // CHECK: llvm.call @mgpuSpGEMMWorkEstimation
   // CHECK: llvm.call @mgpuMemAlloc
   // CHECK: llvm.call @mgpuSpGEMMWorkEstimation
-  // CHECK: llvm.call @mgpuSpGEMMEstimateMemory
-  // CHECK: llvm.call @mgpuMemAlloc
-  // CHECK: llvm.call @mgpuSpGEMMEstimateMemory
   // CHECK: llvm.call @mgpuMemAlloc
   // CHECK: llvm.call @mgpuSpGEMMCompute
   // CHECK: llvm.call @mgpuMemAlloc
@@ -130,19 +127,10 @@ module attributes {gpu.container_module} {
                               [%token8]{WORK_ESTIMATION} %spmatA, %spmatB,
                               %spmatC, %spgemmDesc, %bufferSz1,
                               %buf1: f32 into memref<?xi8>
-    %bufferSz3, %dummy, %token10 = gpu.spgemm_estimate_memory async [%token9]
-                                     %spmatA, %spmatB, %spmatC,
-                                     %spgemmDesc, %c0, %c0,
-                                     %alloc: f32 into memref<0xi8>
-    %buf3, %token11 = gpu.alloc async [%token10] (%bufferSz3) : memref<?xi8>
-    %bufferSz3_2, %bufferSz2, %token12 = gpu.spgemm_estimate_memory async
-                                          [%token11] %spmatA, %spmatB, %spmatC,
-                                          %spgemmDesc, %bufferSz3, %c0,
-                                          %buf3: f32 into memref<?xi8>
-    %buf2, %token13 = gpu.alloc async [%token12] (%bufferSz2) : memref<?xi8>
+    %buf2, %token13 = gpu.alloc async [%token9] (%bufferSz1_1) : memref<?xi8>
     %bufferSz2_2, %token14 = gpu.spgemm_work_estimation_or_compute async
                                [%token13]{COMPUTE} %spmatA, %spmatB, %spmatC,
-                               %spgemmDesc, %bufferSz2,
+                               %spgemmDesc, %bufferSz1_1,
                                %buf2: f32 into memref<?xi8>
     %rows, %cols, %nnz, %token15 = gpu.spgemm_get_size async [%token14] %spmatC
     %mem_columns, %token16 = gpu.alloc async [%token15] (%cols) : memref<?xi32>

diff  --git a/mlir/test/Dialect/GPU/sparse-roundtrip.mlir b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
index b39bc101af17f9..171a1ad24898ff 100644
--- a/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
+++ b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
@@ -55,72 +55,45 @@ module attributes {gpu.container_module} {
   // CHECK-LABEL:     func @spgemm
-  // CHECK:      %{{.*}} = gpu.wait async
+  // CHECK:           %{{.*}} = gpu.wait async
   // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xindex>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf64>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf32>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf32>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf32>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf32>
   // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_create_descr async [%{{.*}}]
-  // CHECK:           %{{.*}} = memref.alloc() : memref<0xi8>
-  // CHECK:           %{{.*}} = arith.constant 0 : index
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_estimate_memory async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_estimate_memory async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{ WORK_ESTIMATION} %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{ COMPUTE} %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
   // CHECK:           %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_get_size async [%{{.*}}] %{{.*}}
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi32>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf32>
-  // CHECK:           gpu.wait [%{{.*}}]
-  // CHECK:           gpu.spgemm_copy  %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32
-  // CHECK:           gpu.destroy_sp_mat  %{{.*}}
-  // CHECK:           gpu.destroy_sp_mat  %{{.*}}
-  // CHECK:           gpu.destroy_sp_mat  %{{.*}}
+  // CHECK            %{{.*}} = gpu.set_csr_pointers async [%{{.*}}] %{{.*}}, {{.*}}, {{.*}}, {{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf32>
+  // CHECK:           %{{.*}} = gpu.spgemm_copy async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32
+  // CHECK:           %{{.*}} = gpu.spgemm_destroy_descr async [%{{.*}}] %{{.*}}
+  // CHECK:           gpu.destroy_sp_mat %{{.*}}
+  // CHECK:           gpu.destroy_sp_mat %{{.*}}
+  // CHECK:           gpu.destroy_sp_mat %{{.*}}
   // CHECK:           return
   func.func @spgemm(%arg0: index) {
     %token0 = gpu.wait async
     %mem1, %token1 = gpu.alloc async [%token0] (%arg0) : memref<?xindex>
-    %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf64>
-    %spmatA, %token3 = gpu.create_csr async [%token2] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
-    %spmatB, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
-    %spmatC, %token5 = gpu.create_csr async [%token4] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+    %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf32>
+    %spmatA, %token3 = gpu.create_csr async [%token2] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf32>
+    %spmatB, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf32>
+    %spmatC, %token5 = gpu.create_csr async [%token4] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf32>
     %spgemmDesc, %token6 = gpu.spgemm_create_descr async [%token5]
-    // Used as nullptr
-    %alloc = memref.alloc() : memref<0xi8>
+    %alloc = memref.alloc() : memref<0xi8>  // nullptr
     %c0 = arith.constant 0 : index
     %bufferSz1, %token7 = gpu.spgemm_work_estimation_or_compute async
-                            %spmatA{NON_TRANSPOSE}, %spmatB{NON_TRANSPOSE},
-                            %spmatC, %spgemmDesc, %c0,
-                            %alloc: f32 into memref<0xi8>
-    %buf1, %token8 = gpu.alloc async [%token7] (%bufferSz1) : memref<?xi8>
-    %bufferSz1_1, %token9 = gpu.spgemm_work_estimation_or_compute async
-                              [%token8]{WORK_ESTIMATION} %spmatA, %spmatB,
-                              %spmatC, %spgemmDesc, %bufferSz1,
-                              %buf1: f32 into memref<?xi8>
-    %bufferSz3, %dummy, %token10 = gpu.spgemm_estimate_memory async [%token9]
-                                     %spmatA, %spmatB, %spmatC,
-                                     %spgemmDesc, %c0, %c0,
-                                     %alloc: f32 into memref<0xi8>
-    %buf3, %token11 = gpu.alloc async [%token10] (%bufferSz3) : memref<?xi8>
-    %bufferSz3_2, %bufferSz2, %token12 = gpu.spgemm_estimate_memory async
-                                          [%token11] %spmatA, %spmatB, %spmatC,
-                                          %spgemmDesc, %bufferSz3, %c0,
-                                          %buf3: f32 into memref<?xi8>
-    %buf2, %token13 = gpu.alloc async [%token12] (%bufferSz2) : memref<?xi8>
-    %bufferSz2_2, %token14 = gpu.spgemm_work_estimation_or_compute async
-                               [%token13]{COMPUTE} %spmatA, %spmatB, %spmatC,
-                               %spgemmDesc, %bufferSz2,
-                               %buf2: f32 into memref<?xi8>
-    %rows, %cols, %nnz, %token15 = gpu.spgemm_get_size async [%token14] %spmatC
-    %mem_columns, %token16 = gpu.alloc async [%token15] (%cols) : memref<?xi32>
-    %mem_values, %token17 = gpu.alloc async [%token16] (%nnz) : memref<?xf32>
-    gpu.wait [%token17]
-    gpu.spgemm_copy %spmatA, %spmatB, %spmatC, %spgemmDesc: f32
+                            %spmatA, %spmatB, %spmatC,
+			    %spgemmDesc, %c0, %alloc: f32 into memref<0xi8>
+    %bufferSz2, %token8 = gpu.spgemm_work_estimation_or_compute async
+                               [%token7]{COMPUTE}
+			       %spmatA, %spmatB, %spmatC,
+                               %spgemmDesc, %c0, %alloc: f32 into memref<0xi8>
+    %rows, %cols, %nnz, %token9 = gpu.spgemm_get_size async [%token8] %spmatC
+    %token10 = gpu.set_csr_pointers async [%token8] %spmatC, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf32>
+    %token11 = gpu.spgemm_copy async [%token10] %spmatA, %spmatB, %spmatC, %spgemmDesc: f32
+    %token12 = gpu.spgemm_destroy_descr async [%token11] %spgemmDesc
     gpu.destroy_sp_mat %spmatA
     gpu.destroy_sp_mat %spmatB
     gpu.destroy_sp_mat %spmatC


More information about the Mlir-commits mailing list