[Mlir-commits] [mlir] e7e4ed0 - [mlir][sparse][gpu] only support default algorithm for SpGEMM

Aart Bik llvmlistbot at llvm.org
Wed Aug 9 12:49:56 PDT 2023


Author: Aart Bik
Date: 2023-08-09T12:49:47-07:00
New Revision: e7e4ed0d7a28b6d7d7b7211b42c02d72e930dec1

URL: https://github.com/llvm/llvm-project/commit/e7e4ed0d7a28b6d7d7b7211b42c02d72e930dec1
DIFF: https://github.com/llvm/llvm-project/commit/e7e4ed0d7a28b6d7d7b7211b42c02d72e930dec1.diff

LOG: [mlir][sparse][gpu] only support default algorithm for SpGEMM

Rationale:
This matches the approach taken for the other sparse operations (SpMV,
SpMM, SDDMM), so it is more consistent to follow the same path (until
we have a need for more algorithms). Also, in a follow-up revision,
this will allow us to remove some unused GEMM ops.

Reviewed By: K-Wu

Differential Revision: https://reviews.llvm.org/D157542
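
For reference, a minimal sketch of the syntactic effect of this change, taken
from the updated round-trip test below (a fragment, not a standalone module;
the SSA names are simply the ones used in that test). The only change to the
assembly format of the SpGEMM ops is that the explicit algorithm operand
(e.g. ALG2) between the sparse matrix operands and the descriptor is removed:

```mlir
// The algorithm operand (formerly e.g. ALG2 before %spgemmDesc) is gone;
// the lowering now always uses CUSPARSE_SPGEMM_DEFAULT.
%token19 = gpu.spgemm_copy async [%token18] %spmatA, %spmatB, %spmatC, %spgemmDesc : f32
```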

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
    mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
    mlir/test/Dialect/GPU/sparse-roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index e65d05f92ef8ed..8dc88663c0c8c2 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -2149,23 +2149,6 @@ def GPU_SDDMMOp : GPU_Op<"sddmm", [GPU_AsyncOpInterface]> {
   }];
 }
 
-// ALG1, ALG2, ALG3 use 3--5 to align with cusparseSpGEMMAlg_t in cusparse.h.
-def GPU_SpGEMMAlg : I32EnumAttr<"SpGEMMAlg",
-    "selected algorithm for sparse matrix SpGEMM",
-    [
-      I32EnumAttrCase<"ALG1", 3>,
-      I32EnumAttrCase<"ALG2", 4>,
-      I32EnumAttrCase<"ALG3", 5>,
-    ]> {
-      let genSpecializedAttr = 0;
-      let cppNamespace = GPU_Dialect.cppNamespace;
-      let defaultValue = "SpGEMMAlg::ALG1";
-}
-
-def GPU_SpGEMMAlgAttr : EnumAttr<GPU_Dialect, GPU_SpGEMMAlg, "spgemm_alg"> {
-  let defaultValue = GPU_SpGEMMAlg.defaultValue;
-}
-
 def GPU_SpGEMMWorkEstimationOrComputeKind : I32EnumAttr<"SpGEMMWorkEstimationOrComputeKind",
     "choose whether spgemm_work_estimation_or_compute does work estimation or compute",
     [
@@ -2195,9 +2178,8 @@ def GPU_SpGEMMCreateDescrOp : GPU_Op<"spgemm_create_descr", [GPU_AsyncOpInterfac
     Example:
 
     ```mlir
-    %desc,  %token = gpu.spgemm_create_descr async [%dep]
+    %desc, %token = gpu.spgemm_create_descr async [%dep]
     ```
-
   }];
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies);
   let results = (outs GPU_SparseSpGEMMOpHandle:$desc,
@@ -2222,7 +2204,6 @@ def GPU_SpGEMMDestroyDescrOp : GPU_Op<"spgemm_destroy_descr", [GPU_AsyncOpInterf
     ```mlir
     %token = gpu.spgemm_destroy_descr async [%dep] %desc
     ```
-
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
@@ -2234,7 +2215,6 @@ def GPU_SpGEMMDestroyDescrOp : GPU_Op<"spgemm_destroy_descr", [GPU_AsyncOpInterf
   }];
 }
 
-
 def GPU_SpGEMMWorkEstimationOrComputeOp : GPU_Op<"spgemm_work_estimation_or_compute", [GPU_AsyncOpInterface]> {
   let summary = "SpGEMM work estimation operation";
   let description = [{
@@ -2245,7 +2225,6 @@ def GPU_SpGEMMWorkEstimationOrComputeOp : GPU_Op<"spgemm_work_estimation_or_comp
     construct an environment and the operands for SpGEMM.
     The buffer must have been allocated on the device.
 
-
     C' = alpha * op(A) * op(B) + beta * C
 
     If the `async` keyword is present, the op is executed asynchronously (i.e.
@@ -2264,7 +2243,6 @@ def GPU_SpGEMMWorkEstimationOrComputeOp : GPU_Op<"spgemm_work_estimation_or_comp
     The matrix arguments can also be associated with one of the following
     operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
     is NON_TRANSPOSE.
-
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
@@ -2276,7 +2254,6 @@ def GPU_SpGEMMWorkEstimationOrComputeOp : GPU_Op<"spgemm_work_estimation_or_comp
                        GPU_SparseSpMatHandle:$spmatC,
                        TypeAttr:$computeType,
                        Index:$bufferSz,
-                       GPU_SpGEMMAlgAttr:$alg,
                        AnyMemRef:$buffer,
                        GPU_SpGEMMWorkEstimationOrComputeKindAttr:$kind);
   let results = (outs Res<Index>:$bufferSzNew,
@@ -2295,19 +2272,17 @@ def GPU_SpGEMMWorkEstimationOrComputeOp : GPU_Op<"spgemm_work_estimation_or_comp
     "Value":$buffer), [{
   auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
   auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
-  auto alg = gpu::SpGEMMAlg::ALG1;
   auto kind = gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION;
   return build($_builder, $_state, bufferSzNew, asyncToken, asyncDependencies, desc,
-               modeA, modeB, spmatA, spmatB, spmatC, computeType, bufferSz, alg, buffer, kind);}]>
+               modeA, modeB, spmatA, spmatB, spmatC, computeType, bufferSz, buffer, kind);}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    `{` $kind `}` $spmatA (`{` $modeA^ `}`)? `,` $spmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $alg `,` $desc `,` $bufferSz `,` $buffer  attr-dict `:` $computeType `into` type($buffer)
+    `{` $kind `}` $spmatA (`{` $modeA^ `}`)? `,` $spmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $desc `,` $bufferSz `,` $buffer  attr-dict `:` $computeType `into` type($buffer)
   }];
 }
 
-
 def GPU_SpGEMMEstimateMemoryOp : GPU_Op<"spgemm_estimate_memory", [GPU_AsyncOpInterface]> {
   let summary = "SpGEMM estimate memory operation";
   let description = [{
@@ -2323,7 +2298,6 @@ def GPU_SpGEMMEstimateMemoryOp : GPU_Op<"spgemm_estimate_memory", [GPU_AsyncOpIn
     ```mlir
     %bufferSz3, %dummy, %token = gpu.spgemm_estimate_memory async [%dep] %spmatA, %spmatB, %spmatC, ALG2, %spgemmDesc, %c0, %c0, %alloc: f32 into memref<0xi8>
     ```
-
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
@@ -2334,7 +2308,6 @@ def GPU_SpGEMMEstimateMemoryOp : GPU_Op<"spgemm_estimate_memory", [GPU_AsyncOpIn
                        GPU_SparseSpMatHandle:$spmatB,
                        GPU_SparseSpMatHandle:$spmatC,
                        TypeAttr:$computeType,
-                       GPU_SpGEMMAlgAttr:$alg,
                        Index:$bufferSz3,
                        AnyMemRef:$buffer3,
                        Index:$bufferSz2);
@@ -2357,19 +2330,17 @@ def GPU_SpGEMMEstimateMemoryOp : GPU_Op<"spgemm_estimate_memory", [GPU_AsyncOpIn
     "Value":$bufferSz2), [{
   auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
   auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
-  auto alg = gpu::SpGEMMAlg::ALG1;
   return build($_builder, $_state, bufferSz3New, bufferSz2New, asyncToken,
                asyncDependencies, desc, modeA, modeB, spmatA, spmatB, spmatC,
-               computeType, alg, bufferSz3, buffer3, bufferSz2);}]>
+               computeType, bufferSz3, buffer3, bufferSz2);}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $spmatA (`{` $modeA^ `}`)? `,` $spmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $alg `,` $desc `,` $bufferSz3 `,` $bufferSz2 `,` $buffer3 attr-dict `:` $computeType `into` type($buffer3)
+    $spmatA (`{` $modeA^ `}`)? `,` $spmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $desc `,` $bufferSz3 `,` $bufferSz2 `,` $buffer3 attr-dict `:` $computeType `into` type($buffer3)
   }];
 }
 
-
 def GPU_SpGEMMCopyOp : GPU_Op<"spgemm_copy", [GPU_AsyncOpInterface]> {
   let summary = "SpGEMM copy operation";
   let description = [{
@@ -2389,7 +2360,6 @@ def GPU_SpGEMMCopyOp : GPU_Op<"spgemm_copy", [GPU_AsyncOpInterface]> {
     The matrix arguments can also be associated with one of the following
     operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
     is NON_TRANSPOSE.
-
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
@@ -2399,8 +2369,7 @@ def GPU_SpGEMMCopyOp : GPU_Op<"spgemm_copy", [GPU_AsyncOpInterface]> {
                        GPU_SparseSpMatHandle:$spmatA,
                        GPU_SparseSpMatHandle:$spmatB,
                        GPU_SparseSpMatHandle:$spmatC,
-                       TypeAttr:$computeType,
-                       GPU_SpGEMMAlgAttr:$alg);
+                       TypeAttr:$computeType);
   let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
 
   let builders = [OpBuilder<(ins
@@ -2413,18 +2382,16 @@ def GPU_SpGEMMCopyOp : GPU_Op<"spgemm_copy", [GPU_AsyncOpInterface]> {
     "Type":$computeType), [{
   auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
   auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
-  auto alg = gpu::SpGEMMAlg::ALG1;
   return build($_builder, $_state, asyncToken, asyncDependencies, desc,
-               modeA, modeB, spmatA, spmatB, spmatC, computeType, alg);}]>
+               modeA, modeB, spmatA, spmatB, spmatC, computeType);}]>
   ];
 
   let assemblyFormat = [{
     custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
-    $spmatA (`{` $modeA^ `}`)? `,` $spmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $alg `,` $desc attr-dict `:` $computeType
+    $spmatA (`{` $modeA^ `}`)? `,` $spmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $desc attr-dict `:` $computeType
   }];
 }
 
-
 def GPU_SpGEMMGetSizeOp : GPU_Op<"spgemm_get_size", [GPU_AsyncOpInterface]> {
   let summary = "SpGEMM get size operation";
   let description = [{
@@ -2440,11 +2407,6 @@ def GPU_SpGEMMGetSizeOp : GPU_Op<"spgemm_get_size", [GPU_AsyncOpInterface]> {
     ```mlir
     %rows, %cols, %nnz, %token = gpu.spgemm_get_size async [%dep] %spmatC
     ```
-
-    The matrix arguments can also be associated with one of the following
-    operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
-    is NON_TRANSPOSE.
-
   }];
 
   let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 7428b5ebe521d9..37efadd1be5625 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -300,32 +300,30 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       llvmIntPtrType,
       {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
        llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
-       llvmInt32Type /*ctp*/, llvmInt32Type /*alg*/, llvmIntPtrType /*bs*/,
-       llvmPointerType /*buf*/, llvmPointerType /*void *stream*/}};
+       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
+       llvmPointerType /*void *stream*/}};
   FunctionCallBuilder createSpGEMMEstimateMemoryBuilder = {
       "mgpuSpGEMMEstimateMemory",
       llvmVoidType,
       {llvmPointerType /*nbs3*/, llvmPointerType /*nbs2*/,
        llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
        llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
-       llvmInt32Type /*ctp*/, llvmInt32Type /*alg*/,
-       llvmFloat32Type /*chunk_fraction*/, llvmIntPtrType /*bs3*/,
-       llvmPointerType /*buf3*/, llvmIntPtrType /*bs2*/,
+       llvmInt32Type /*ctp*/, llvmFloat32Type /*chunk_fraction*/,
+       llvmIntPtrType /*bs3*/, llvmPointerType /*buf3*/, llvmIntPtrType /*bs2*/,
        llvmPointerType /*void *stream*/}};
   FunctionCallBuilder createSpGEMMComputeBuilder = {
       "mgpuSpGEMMCompute",
       llvmIntPtrType,
       {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
        llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
-       llvmInt32Type /*ctp*/, llvmInt32Type /*alg*/, llvmIntPtrType /*bs*/,
-       llvmPointerType /*buf*/, llvmPointerType /*void *stream*/}};
+       llvmInt32Type /*ctp*/, llvmIntPtrType /*bs*/, llvmPointerType /*buf*/,
+       llvmPointerType /*void *stream*/}};
   FunctionCallBuilder createSpGEMMCopyBuilder = {
       "mgpuSpGEMMCopy",
       llvmVoidType,
       {llvmPointerType /*s*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
        llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
-       llvmInt32Type /*ctp*/, llvmInt32Type /*alg*/,
-       llvmPointerType /*void *stream*/}};
+       llvmInt32Type /*ctp*/, llvmPointerType /*void *stream*/}};
   FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
       "mgpuSpGEMMCreateDescr",
       llvmPointerType,
@@ -1735,7 +1733,6 @@ ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
       rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
-  auto alg = genConstInt32From(rewriter, loc, adaptor.getAlg());
   auto stream = adaptor.getAsyncDependencies().front();
 
   Value pBuf =
@@ -1751,7 +1748,7 @@ ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
         createSpGEMMWorkEstimationBuilder
             .create(loc, rewriter,
                     {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
-                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType, alg,
+                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                      adaptor.getBufferSz(), pBuf, stream})
             .getResult();
   } else {
@@ -1759,7 +1756,7 @@ ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
         createSpGEMMComputeBuilder
             .create(loc, rewriter,
                     {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
-                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType, alg,
+                     adaptor.getSpmatB(), adaptor.getSpmatC(), computeType,
                      adaptor.getBufferSz(), pBuf, stream})
             .getResult();
   }
@@ -1777,7 +1774,6 @@ ConvertSpGEMMEstimateMemoryOpToGpuRuntimeCallPattern::matchAndRewrite(
   Location loc = op.getLoc();
   auto computeType = genConstInt32From(
       rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
-  auto alg = genConstInt32From(rewriter, loc, adaptor.getAlg());
   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
   auto stream = adaptor.getAsyncDependencies().front();
@@ -1806,7 +1802,7 @@ ConvertSpGEMMEstimateMemoryOpToGpuRuntimeCallPattern::matchAndRewrite(
       loc, rewriter,
       {bufferSizePtr3, bufferSizePtr2, adaptor.getDesc(), modeA, modeB,
        adaptor.getSpmatA(), adaptor.getSpmatB(), adaptor.getSpmatC(),
-       computeType, alg, chunkFraction, adaptor.getBufferSz3(), pBuf3,
+       computeType, chunkFraction, adaptor.getBufferSz3(), pBuf3,
        adaptor.getBufferSz2(), stream});
   auto bufferSize2 =
       rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
@@ -1828,12 +1824,11 @@ LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
       rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
   auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
-  auto alg = genConstInt32From(rewriter, loc, adaptor.getAlg());
   auto stream = adaptor.getAsyncDependencies().front();
-  createSpGEMMCopyBuilder.create(
-      loc, rewriter,
-      {adaptor.getDesc(), modeA, modeB, adaptor.getSpmatA(),
-       adaptor.getSpmatB(), adaptor.getSpmatC(), computeType, alg, stream});
+  createSpGEMMCopyBuilder.create(loc, rewriter,
+                                 {adaptor.getDesc(), modeA, modeB,
+                                  adaptor.getSpmatA(), adaptor.getSpmatB(),
+                                  adaptor.getSpmatC(), computeType, stream});
   rewriter.replaceOp(op, {stream});
   return success();
 }

diff  --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index e747541bff5ab8..23c5e40e438189 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -605,11 +605,10 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSDDMM(int32_t ma, int32_t mb,
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpGEMMWorkEstimation(
     void *s, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t ctp,
-    int32_t alg, intptr_t bs, void *buf, CUstream /*stream*/) {
+    intptr_t bs, void *buf, CUstream /*stream*/) {
   cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
-  cusparseSpGEMMAlg_t algorithm = static_cast<cusparseSpGEMMAlg_t>(alg);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
   cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
@@ -619,15 +618,15 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpGEMMWorkEstimation(
 
   CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_workEstimation(
       cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
-      algorithm, spgemmDesc, &newBufferSize, buf))
+      CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize, buf))
   return newBufferSize == 0 ? 1 : newBufferSize; // avoid zero-alloc
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuSpGEMMEstimateMemory(void *nbs3, void *nbs2, void *s, int32_t ma,
                          int32_t mb, void *a, void *b, void *c, int32_t ctp,
-                         int32_t alg, float chunk_fraction, intptr_t bs3,
-                         void *buf3, intptr_t bs2, CUstream /*stream*/) {
+                         float chunk_fraction, intptr_t bs3, void *buf3,
+                         intptr_t bs2, CUstream /*stream*/) {
   cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
@@ -640,11 +639,10 @@ mgpuSpGEMMEstimateMemory(void *nbs3, void *nbs2, void *s, int32_t ma,
   size_t *newBufferSize3 = reinterpret_cast<size_t *>(nbs3);
   *newBufferSize2 = bs2;
   *newBufferSize3 = bs3;
-  auto algorithm = static_cast<cusparseSpGEMMAlg_t>(alg);
 
   CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_estimateMemory(
       cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
-      algorithm, spgemmDesc, chunk_fraction, newBufferSize3, buf3,
+      CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, chunk_fraction, newBufferSize3, buf3,
       newBufferSize2))
   // avoid zero-alloc
   if (*newBufferSize2 == 0) {
@@ -656,13 +654,12 @@ mgpuSpGEMMEstimateMemory(void *nbs3, void *nbs2, void *s, int32_t ma,
   return;
 }
 
-extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpGEMMCompute(
-    void *s, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t ctp,
-    int32_t alg, intptr_t bsz2, void *buf2, CUstream /*stream*/) {
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
+mgpuSpGEMMCompute(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
+                  int32_t ctp, intptr_t bsz2, void *buf2, CUstream /*stream*/) {
   cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
-  cusparseSpGEMMAlg_t algorithm = static_cast<cusparseSpGEMMAlg_t>(alg);
   cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
   cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
   cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
@@ -671,13 +668,13 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpGEMMCompute(
   size_t newBufferSize2 = bsz2;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_compute(
       cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
-      algorithm, spgemmDesc, &newBufferSize2, buf2))
+      CUSPARSE_SPGEMM_DEFAULT, spgemmDesc, &newBufferSize2, buf2))
   return newBufferSize2 == 0 ? 1 : newBufferSize2; // avoid zero-alloc
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuSpGEMMCopy(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
-               int32_t ctp, int32_t alg, CUstream /*stream*/) {
+               int32_t ctp, CUstream /*stream*/) {
   cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
   cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
   cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
@@ -685,17 +682,15 @@ mgpuSpGEMMCopy(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
   cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
   cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
   auto cTp = static_cast<cudaDataType_t>(ctp);
-  auto algorithm = static_cast<cusparseSpGEMMAlg_t>(alg);
   ALPHABETA(cTp, alpha, beta)
 
-  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_copy(cusparse_env, modeA, modeB,
-                                               alphap, matA, matB, betap, matC,
-                                               cTp, algorithm, spgemmDesc))
+  CUSPARSE_REPORT_IF_ERROR(
+      cusparseSpGEMM_copy(cusparse_env, modeA, modeB, alphap, matA, matB, betap,
+                          matC, cTp, CUSPARSE_SPGEMM_DEFAULT, spgemmDesc))
 }
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
 mgpuSpGEMMCreateDescr(CUstream /*stream*/) {
-  // cusparseSpGEMMDescr_t is a pointer type
   cusparseSpGEMMDescr_t spgemmDesc = nullptr;
   CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_createDescr(&spgemmDesc))
   return reinterpret_cast<void *>(spgemmDesc);
@@ -703,7 +698,6 @@ mgpuSpGEMMCreateDescr(CUstream /*stream*/) {
 
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuSpGEMMDestroyDescr(void *s, CUstream /*stream*/) {
-  // cusparseSpGEMMDescr_t is a pointer type
   cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
   CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_destroyDescr(spgemmDesc))
 }

diff  --git a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
index 5b0472b79c7635..40489295143862 100644
--- a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
@@ -120,36 +120,36 @@ module attributes {gpu.container_module} {
     // Used as nullptr
     %alloc = memref.alloc() : memref<0xi8>
     %c0 = arith.constant 0 : index
-    %bufferSz1, %token7 = gpu.spgemm_work_estimation_or_compute async 
+    %bufferSz1, %token7 = gpu.spgemm_work_estimation_or_compute async
                             [%token6]{WORK_ESTIMATION}
-                            %spmatA{NON_TRANSPOSE}, %spmatB{NON_TRANSPOSE}, 
-                            %spmatC, ALG2, %spgemmDesc, %c0, 
+                            %spmatA{NON_TRANSPOSE}, %spmatB{NON_TRANSPOSE},
+                            %spmatC, %spgemmDesc, %c0,
                             %alloc: f32 into memref<0xi8>
     %buf1, %token8 = gpu.alloc async [%token7] (%bufferSz1) : memref<?xi8>
     %bufferSz1_1, %token9 = gpu.spgemm_work_estimation_or_compute async
-                              [%token8]{WORK_ESTIMATION} %spmatA, %spmatB, 
-                              %spmatC, ALG2, %spgemmDesc, %bufferSz1, 
+                              [%token8]{WORK_ESTIMATION} %spmatA, %spmatB,
+                              %spmatC, %spgemmDesc, %bufferSz1,
                               %buf1: f32 into memref<?xi8>
-    %bufferSz3, %dummy, %token10 = gpu.spgemm_estimate_memory async [%token9] 
-                                     %spmatA, %spmatB, %spmatC, ALG2, 
-                                     %spgemmDesc, %c0, %c0, 
+    %bufferSz3, %dummy, %token10 = gpu.spgemm_estimate_memory async [%token9]
+                                     %spmatA, %spmatB, %spmatC,
+                                     %spgemmDesc, %c0, %c0,
                                      %alloc: f32 into memref<0xi8>
     %buf3, %token11 = gpu.alloc async [%token10] (%bufferSz3) : memref<?xi8>
-    %bufferSz3_2, %bufferSz2, %token12 = gpu.spgemm_estimate_memory async 
+    %bufferSz3_2, %bufferSz2, %token12 = gpu.spgemm_estimate_memory async
                                           [%token11] %spmatA, %spmatB, %spmatC,
-                                          ALG2, %spgemmDesc, %bufferSz3, %c0,
+                                          %spgemmDesc, %bufferSz3, %c0,
                                           %buf3: f32 into memref<?xi8>
     %buf2, %token13 = gpu.alloc async [%token12] (%bufferSz2) : memref<?xi8>
-    %bufferSz2_2, %token14 = gpu.spgemm_work_estimation_or_compute async 
-                               [%token13]{COMPUTE} %spmatA, %spmatB, %spmatC, 
-                               ALG2, %spgemmDesc, %bufferSz2, 
+    %bufferSz2_2, %token14 = gpu.spgemm_work_estimation_or_compute async
+                               [%token13]{COMPUTE} %spmatA, %spmatB, %spmatC,
+                               %spgemmDesc, %bufferSz2,
                                %buf2: f32 into memref<?xi8>
     %rows, %cols, %nnz, %token15 = gpu.spgemm_get_size async [%token14] %spmatC
     %mem_columns, %token16 = gpu.alloc async [%token15] (%cols) : memref<?xi32>
     %mem_values, %token17 = gpu.alloc async [%token16] (%nnz) : memref<?xf32>
     gpu.wait [%token17]
     %token18 = gpu.wait async
-    %token19 = gpu.spgemm_copy async [%token18] %spmatA, %spmatB, %spmatC, ALG2, %spgemmDesc: f32
+    %token19 = gpu.spgemm_copy async [%token18] %spmatA, %spmatB, %spmatC, %spgemmDesc: f32
     %token20 = gpu.destroy_sp_mat async [%token19] %spmatA
     %token21 = gpu.destroy_sp_mat async [%token20] %spmatB
     %token22 = gpu.destroy_sp_mat async [%token21] %spmatC
@@ -158,5 +158,3 @@ module attributes {gpu.container_module} {
   }
 
 }
-
-

diff  --git a/mlir/test/Dialect/GPU/sparse-roundtrip.mlir b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
index bf669bd3b46c99..b39bc101af17f9 100644
--- a/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
+++ b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
@@ -64,19 +64,19 @@ module attributes {gpu.container_module} {
   // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_create_descr async [%{{.*}}]
   // CHECK:           %{{.*}} = memref.alloc() : memref<0xi8>
   // CHECK:           %{{.*}} = arith.constant 0 : index
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}},  ALG2, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
   // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}},  ALG2, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_estimate_memory async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}},  ALG2, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
+  // CHECK:           %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_estimate_memory async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
   // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_estimate_memory async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}},  ALG2, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
+  // CHECK:           %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_estimate_memory async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
   // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi8>
-  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}},  ALG2, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
   // CHECK:           %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_get_size async [%{{.*}}] %{{.*}}
   // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi32>
   // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf32>
   // CHECK:           gpu.wait [%{{.*}}]
-  // CHECK:           gpu.spgemm_copy  %{{.*}}, %{{.*}}, %{{.*}},  ALG2, %{{.*}} : f32
+  // CHECK:           gpu.spgemm_copy  %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32
   // CHECK:           gpu.destroy_sp_mat  %{{.*}}
   // CHECK:           gpu.destroy_sp_mat  %{{.*}}
   // CHECK:           gpu.destroy_sp_mat  %{{.*}}
@@ -92,35 +92,35 @@ module attributes {gpu.container_module} {
     // Used as nullptr
     %alloc = memref.alloc() : memref<0xi8>
     %c0 = arith.constant 0 : index
-    %bufferSz1, %token7 = gpu.spgemm_work_estimation_or_compute async 
+    %bufferSz1, %token7 = gpu.spgemm_work_estimation_or_compute async
                             [%token6]{WORK_ESTIMATION}
-                            %spmatA{NON_TRANSPOSE}, %spmatB{NON_TRANSPOSE}, 
-                            %spmatC, ALG2, %spgemmDesc, %c0, 
+                            %spmatA{NON_TRANSPOSE}, %spmatB{NON_TRANSPOSE},
+                            %spmatC, %spgemmDesc, %c0,
                             %alloc: f32 into memref<0xi8>
     %buf1, %token8 = gpu.alloc async [%token7] (%bufferSz1) : memref<?xi8>
-    %bufferSz1_1, %token9 = gpu.spgemm_work_estimation_or_compute async 
-                              [%token8]{WORK_ESTIMATION} %spmatA, %spmatB, 
-                              %spmatC, ALG2, %spgemmDesc, %bufferSz1, 
+    %bufferSz1_1, %token9 = gpu.spgemm_work_estimation_or_compute async
+                              [%token8]{WORK_ESTIMATION} %spmatA, %spmatB,
+                              %spmatC, %spgemmDesc, %bufferSz1,
                               %buf1: f32 into memref<?xi8>
-    %bufferSz3, %dummy, %token10 = gpu.spgemm_estimate_memory async [%token9] 
-                                     %spmatA, %spmatB, %spmatC, ALG2, 
-                                     %spgemmDesc, %c0, %c0, 
+    %bufferSz3, %dummy, %token10 = gpu.spgemm_estimate_memory async [%token9]
+                                     %spmatA, %spmatB, %spmatC,
+                                     %spgemmDesc, %c0, %c0,
                                      %alloc: f32 into memref<0xi8>
     %buf3, %token11 = gpu.alloc async [%token10] (%bufferSz3) : memref<?xi8>
-    %bufferSz3_2, %bufferSz2, %token12 = gpu.spgemm_estimate_memory async 
+    %bufferSz3_2, %bufferSz2, %token12 = gpu.spgemm_estimate_memory async
                                           [%token11] %spmatA, %spmatB, %spmatC,
-                                          ALG2, %spgemmDesc, %bufferSz3, %c0,
+                                          %spgemmDesc, %bufferSz3, %c0,
                                           %buf3: f32 into memref<?xi8>
     %buf2, %token13 = gpu.alloc async [%token12] (%bufferSz2) : memref<?xi8>
-    %bufferSz2_2, %token14 = gpu.spgemm_work_estimation_or_compute async 
-                               [%token13]{COMPUTE} %spmatA, %spmatB, %spmatC, 
-                               ALG2, %spgemmDesc, %bufferSz2, 
+    %bufferSz2_2, %token14 = gpu.spgemm_work_estimation_or_compute async
+                               [%token13]{COMPUTE} %spmatA, %spmatB, %spmatC,
+                               %spgemmDesc, %bufferSz2,
                                %buf2: f32 into memref<?xi8>
     %rows, %cols, %nnz, %token15 = gpu.spgemm_get_size async [%token14] %spmatC
     %mem_columns, %token16 = gpu.alloc async [%token15] (%cols) : memref<?xi32>
     %mem_values, %token17 = gpu.alloc async [%token16] (%nnz) : memref<?xf32>
     gpu.wait [%token17]
-    gpu.spgemm_copy %spmatA, %spmatB, %spmatC, ALG2, %spgemmDesc: f32
+    gpu.spgemm_copy %spmatA, %spmatB, %spmatC, %spgemmDesc: f32
     gpu.destroy_sp_mat %spmatA
     gpu.destroy_sp_mat %spmatB
     gpu.destroy_sp_mat %spmatC
@@ -154,5 +154,3 @@ module attributes {gpu.container_module} {
   }
 
 }
-
-


        

