[Mlir-commits] [mlir] dfe2942 - [mlir][sparse][gpu] add spgemm operator

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 7 17:35:54 PDT 2023


Author: Kun Wu
Date: 2023-08-08T00:29:23Z
New Revision: dfe29429094852ca3b62752f2acf11b280e9f610

URL: https://github.com/llvm/llvm-project/commit/dfe29429094852ca3b62752f2acf11b280e9f610
DIFF: https://github.com/llvm/llvm-project/commit/dfe29429094852ca3b62752f2acf11b280e9f610.diff

LOG: [mlir][sparse][gpu] add spgemm operator

Differential Revision: https://reviews.llvm.org/D152981

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
    mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
    mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
    mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
    mlir/test/Dialect/GPU/sparse-roundtrip.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
index 842e63aff7bd9f..bb99d4b481ec51 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -110,17 +110,16 @@ class MMAMatrixOf<list<Type> allowedTypes> :
   "gpu.mma_matrix", "::mlir::gpu::MMAMatrixType">;
 
 // Types for all sparse handles.
-def GPU_SparseDnTensorHandle :
-  DialectType<GPU_Dialect,
-    CPred<"llvm::isa<::mlir::gpu::SparseDnTensorHandleType>($_self)">,
-    "dense tensor handle type">,
-  BuildableType<"mlir::gpu::SparseDnTensorHandleType::get($_builder.getContext())">;
-
-def GPU_SparseSpMatHandle :
-  DialectType<GPU_Dialect,
-    CPred<"llvm::isa<::mlir::gpu::SparseSpMatHandleType>($_self)">,
-    "sparse matrix handle type">,
-  BuildableType<"mlir::gpu::SparseSpMatHandleType::get($_builder.getContext())">;
+// Sparse handle type backed by the C++ class ::mlir::gpu::`typeStr`.
+class GPU_SparseHandle<string typeStr, string description> :
+  DialectType<GPU_Dialect,
+    CPred<"llvm::isa<::mlir::gpu::"#typeStr#">($_self)">,
+    description#" handle type">,
+  BuildableType<"mlir::gpu::"#typeStr#"::get($_builder.getContext())">;
+
+def GPU_SparseDnTensorHandle : GPU_SparseHandle<"SparseDnTensorHandleType", "dense tensor">;
+def GPU_SparseSpGEMMOpHandle : GPU_SparseHandle<"SparseSpGEMMOpHandleType", "SpGEMM operation">;
+def GPU_SparseSpMatHandle : GPU_SparseHandle<"SparseSpMatHandleType", "sparse matrix">;
 
 //===----------------------------------------------------------------------===//
 // GPU Interfaces.

diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 1178c0895b5024..c27306cb775b13 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -165,7 +165,7 @@ class MMAMatrixType
 void addAsyncDependency(Operation *op, Value token);
 
 // Handle types for sparse.
-enum class SparseHandleKind { SpMat, DnTensor };
+enum class SparseHandleKind { SpMat, DnTensor, SpGEMMOp };
 
 template <SparseHandleKind K>
 class SparseHandleType
@@ -178,6 +178,7 @@ class SparseHandleType
 
 using SparseDnTensorHandleType = SparseHandleType<SparseHandleKind::DnTensor>;
 using SparseSpMatHandleType = SparseHandleType<SparseHandleKind::SpMat>;
+using SparseSpGEMMOpHandleType = SparseHandleType<SparseHandleKind::SpGEMMOp>;
 
 } // namespace gpu
 } // namespace mlir

diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 3b20ba2b46e351..e83fe3303c17b5 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -2130,4 +2130,315 @@ def GPU_SDDMMOp : GPU_Op<"sddmm", [GPU_AsyncOpInterface]> {
   }];
 }
 
+// Case values 3-5 align ALG1-ALG3 with cusparseSpGEMMAlg_t in cusparse.h.
+def GPU_SpGEMMAlg : I32EnumAttr<"SpGEMMAlg",
+    "selected algorithm for sparse matrix SpGEMM",
+    [
+      I32EnumAttrCase<"ALG1", 3>,
+      I32EnumAttrCase<"ALG2", 4>,
+      I32EnumAttrCase<"ALG3", 5>,
+    ]> {
+  let cppNamespace = GPU_Dialect.cppNamespace;
+  let genSpecializedAttr = 0;
+  let defaultValue = "SpGEMMAlg::ALG1";
+}
+
+def GPU_SpGEMMAlgAttr : EnumAttr<GPU_Dialect, GPU_SpGEMMAlg, "spgemm_alg"> {
+  let defaultValue = GPU_SpGEMMAlg.defaultValue;
+}
+
+// Selects whether gpu.spgemm_work_estimation_or_compute estimates or computes.
+def GPU_SpGEMMWorkEstimationOrComputeKind : I32EnumAttr<"SpGEMMWorkEstimationOrComputeKind",
+    "choose whether spgemm_work_estimation_or_compute does work estimation or compute",
+    [
+      I32EnumAttrCase<"WORK_ESTIMATION", 0>,
+      I32EnumAttrCase<"COMPUTE", 1>,
+    ]> {
+  let cppNamespace = GPU_Dialect.cppNamespace;
+  let genSpecializedAttr = 0;
+}
+
+def GPU_SpGEMMWorkEstimationOrComputeKindAttr : EnumAttr<GPU_Dialect,
+    GPU_SpGEMMWorkEstimationOrComputeKind, "spgemm_work_estimation_or_compute_kind"> {}
+
+def GPU_SpGEMMCreateDescrOp : GPU_Op<"spgemm_create_descr", [GPU_AsyncOpInterface]> {
+  let summary = "SpGEMM Create Descr operation";
+  let description = [{
+    The `gpu.spgemm_create_descr` creates a descriptor for the SpGEMM operation.
+    The descriptor describes the SpGEMM operation and stores the internal data
+    throughout the computation. It needs to be passed as an argument to
+    spgemm_* operations.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a `!gpu.async.token` in addition to the descriptor.
+
+    Example:
+
+    ```mlir
+    %desc, %token = gpu.spgemm_create_descr async [%dep]
+    ```
+
+  }];
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies);
+  let results = (outs GPU_SparseSpGEMMOpHandle:$desc,
+                      Optional<GPU_AsyncToken>:$asyncToken);
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    attr-dict
+  }];
+}
+
+def GPU_SpGEMMDestroyDescrOp : GPU_Op<"spgemm_destroy_descr", [GPU_AsyncOpInterface]> {
+  let summary = "SpGEMM Destroy Descr operation";
+  let description = [{
+    The `gpu.spgemm_destroy_descr` destroys the SpGEMM operation descriptor.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a `!gpu.async.token`.
+
+    Example:
+
+    ```mlir
+    %token = gpu.spgemm_destroy_descr async [%dep] %desc
+    ```
+
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                       GPU_SparseSpGEMMOpHandle:$desc);
+  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $desc attr-dict
+  }];
+}
+
+
+def GPU_SpGEMMWorkEstimationOrComputeOp : GPU_Op<"spgemm_work_estimation_or_compute", [GPU_AsyncOpInterface]> {
+  let summary = "SpGEMM work estimation or compute operation";
+  let description = [{
+    The `gpu.spgemm_work_estimation_or_compute` is used to call
+    cusparseSpGEMM_workEstimation (kind WORK_ESTIMATION) or
+    cusparseSpGEMM_compute (kind COMPUTE). Each is used both for determining
+    the buffer size and for performing the actual work. The operation expects
+    handles returned by previous sparse operations and an SpGEMM descriptor.
+    The buffer must have been allocated on the device.
+
+    The overall SpGEMM computation is
+    C' = alpha * op(A) * op(B) + beta * C
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a `!gpu.async.token` in addition to `bufferSzNew`.
+
+    Example:
+
+    ```mlir
+    %bufferSz, %token = gpu.spgemm_work_estimation_or_compute async [%dep] {COMPUTE}
+                          %spmatA{NON_TRANSPOSE}, %spmatB{NON_TRANSPOSE},
+                          %spmatC, ALG2, %spgemmDesc, %c0, %alloc: f32 into
+                          memref<0xi8>
+    ```
+
+    The matrix arguments can also be associated with one of the following
+    operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
+    is NON_TRANSPOSE.
+
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                       GPU_SparseSpGEMMOpHandle:$desc,
+                       GPU_TransposeModeAttr:$modeA,
+                       GPU_TransposeModeAttr:$modeB,
+                       GPU_SparseSpMatHandle:$spmatA,
+                       GPU_SparseSpMatHandle:$spmatB,
+                       GPU_SparseSpMatHandle:$spmatC,
+                       TypeAttr:$computeType,
+                       Index:$bufferSz,
+                       GPU_SpGEMMAlgAttr:$alg,
+                       AnyMemRef:$buffer,
+                       GPU_SpGEMMWorkEstimationOrComputeKindAttr:$kind);
+  let results = (outs Res<Index>:$bufferSzNew,
+                      Optional<GPU_AsyncToken>:$asyncToken);
+
+  let builders = [OpBuilder<(ins
+    "Type":$bufferSzNew,
+    "Type":$asyncToken,
+    "ValueRange":$asyncDependencies,
+    "Value":$desc,
+    "Value":$spmatA,
+    "Value":$spmatB,
+    "Value":$spmatC,
+    "Type":$computeType,
+    "Value":$bufferSz,
+    "Value":$buffer), [{
+  auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+  auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
+  auto alg = gpu::SpGEMMAlg::ALG1;
+  auto kind = gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION;
+  return build($_builder, $_state, bufferSzNew, asyncToken, asyncDependencies, desc,
+               modeA, modeB, spmatA, spmatB, spmatC, computeType, bufferSz, alg, buffer, kind);}]>
+  ];
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    `{` $kind `}` $spmatA (`{` $modeA^ `}`)? `,` $spmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $alg `,` $desc `,` $bufferSz `,` $buffer  attr-dict `:` $computeType `into` type($buffer)
+  }];
+}
+
+
+def GPU_SpGEMMEstimateMemoryOp : GPU_Op<"spgemm_estimate_memory", [GPU_AsyncOpInterface]> {
+  let summary = "SpGEMM estimate memory operation";
+  let description = [{
+    The `gpu.spgemm_estimate_memory` estimates the memory required by the
+    SpGEMM computation and returns updated sizes for buffer3 and buffer2.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a `!gpu.async.token` in addition to the two sizes.
+
+    Example:
+
+    ```mlir
+    %bufferSz3, %dummy, %token = gpu.spgemm_estimate_memory async [%dep] %spmatA, %spmatB, %spmatC, ALG2, %spgemmDesc, %c0, %c0, %alloc: f32 into memref<0xi8>
+    ```
+
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                       GPU_SparseSpGEMMOpHandle:$desc,
+                       GPU_TransposeModeAttr:$modeA,
+                       GPU_TransposeModeAttr:$modeB,
+                       GPU_SparseSpMatHandle:$spmatA,
+                       GPU_SparseSpMatHandle:$spmatB,
+                       GPU_SparseSpMatHandle:$spmatC,
+                       TypeAttr:$computeType,
+                       GPU_SpGEMMAlgAttr:$alg,
+                       Index:$bufferSz3,
+                       AnyMemRef:$buffer3,
+                       Index:$bufferSz2);
+  let results = (outs Index:$bufferSz3New,
+                      Index:$bufferSz2New,
+                      Optional<GPU_AsyncToken>:$asyncToken);
+
+  let builders = [OpBuilder<(ins
+    "Type":$bufferSz3New,
+    "Type":$bufferSz2New,
+    "Type":$asyncToken,
+    "ValueRange":$asyncDependencies,
+    "Value":$desc,
+    "Value":$spmatA,
+    "Value":$spmatB,
+    "Value":$spmatC,
+    "Type":$computeType,
+    "Value":$bufferSz3,
+    "Value":$buffer3,
+    "Value":$bufferSz2), [{
+  auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+  auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
+  auto alg = gpu::SpGEMMAlg::ALG1;
+  return build($_builder, $_state, bufferSz3New, bufferSz2New, asyncToken,
+               asyncDependencies, desc, modeA, modeB, spmatA, spmatB, spmatC,
+               computeType, alg, bufferSz3, buffer3, bufferSz2);}]>
+  ];
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $spmatA (`{` $modeA^ `}`)? `,` $spmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $alg `,` $desc `,` $bufferSz3 `,` $bufferSz2 `,` $buffer3 attr-dict `:` $computeType `into` type($buffer3)
+  }];
+}
+
+
+def GPU_SpGEMMCopyOp : GPU_Op<"spgemm_copy", [GPU_AsyncOpInterface]> {
+  let summary = "SpGEMM copy operation";
+  let description = [{
+    The `gpu.spgemm_copy` operation copies a sparse matrix, e.g., the result of
+    the SpGEMM computation.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a `!gpu.async.token`.
+
+    Example:
+
+    ```mlir
+    gpu.spgemm_copy %spmatA, %spmatB, %spmatC, ALG2, %spgemmDesc: f32
+    ```
+
+    The matrix arguments can also be associated with one of the following
+    operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The default value
+    is NON_TRANSPOSE.
+
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                       GPU_SparseSpGEMMOpHandle:$desc,
+                       GPU_TransposeModeAttr:$modeA,
+                       GPU_TransposeModeAttr:$modeB,
+                       GPU_SparseSpMatHandle:$spmatA,
+                       GPU_SparseSpMatHandle:$spmatB,
+                       GPU_SparseSpMatHandle:$spmatC,
+                       TypeAttr:$computeType,
+                       GPU_SpGEMMAlgAttr:$alg);
+  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+
+  let builders = [OpBuilder<(ins
+    "Type":$asyncToken,
+    "ValueRange":$asyncDependencies,
+    "Value":$desc,
+    "Value":$spmatA,
+    "Value":$spmatB,
+    "Value":$spmatC,
+    "Type":$computeType), [{
+  auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+  auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
+  auto alg = gpu::SpGEMMAlg::ALG1;
+  return build($_builder, $_state, asyncToken, asyncDependencies, desc,
+               modeA, modeB, spmatA, spmatB, spmatC, computeType, alg);}]>
+  ];
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $spmatA (`{` $modeA^ `}`)? `,` $spmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $alg `,` $desc attr-dict `:` $computeType
+  }];
+}
+
+
+def GPU_SpGEMMGetSizeOp : GPU_Op<"spgemm_get_size", [GPU_AsyncOpInterface]> {
+  let summary = "SpGEMM get size operation";
+  let description = [{
+    The `gpu.spgemm_get_size` operation retrieves the number of rows, number of
+    columns, and number of non-zero elements of a sparse matrix.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a `!gpu.async.token` in addition to the three sizes.
+
+    Example:
+
+    ```mlir
+    %rows, %cols, %nnz, %token = gpu.spgemm_get_size async [%dep] %spmatC
+    ```
+
+    Unlike the other spgemm_* operations, this op takes no transpose
+    modifier; the sizes are those stored in the given sparse matrix handle
+    and are returned as `index` values.
+
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                       GPU_SparseSpMatHandle:$spmat);
+  let results = (outs Index:$rows,
+                      Index:$cols,
+                      Index:$nnz,
+                      Optional<GPU_AsyncToken>:$asyncToken);
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $spmat attr-dict
+  }];
+}
+
 #endif // GPU_OPS

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 9993c093badc16..e72dcc89157227 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -89,6 +89,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
   Type llvmInt16Type = IntegerType::get(context, 16);
   Type llvmInt32Type = IntegerType::get(context, 32);
   Type llvmInt64Type = IntegerType::get(context, 64);
+  Type llvmFloat32Type = Float32Type::get(context);
   Type llvmInt8PointerType =
       this->getTypeConverter()->getPointerType(llvmInt8Type);
   Type llvmInt64PointerType =
@@ -294,6 +295,50 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       llvmVoidType,
       {llvmPointerType, llvmPointerType, llvmPointerType, llvmPointerType,
        llvmPointerType, llvmPointerType, llvmPointerType /*void *stream*/}};
+  FunctionCallBuilder createSpGEMMWorkEstimationBuilder = {
+      "mgpuSpGEMMWorkEstimation",
+      llvmIntPtrType,
+      {llvmPointerType /*desc*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
+       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
+       llvmInt32Type /*ctp*/, llvmInt32Type /*alg*/, llvmIntPtrType /*bs*/,
+       llvmPointerType /*buf*/, llvmPointerType /*void *stream*/}};
+  FunctionCallBuilder createSpGEMMEstimateMemoryBuilder = {
+      "mgpuSpGEMMEstimateMemory",
+      llvmVoidType,
+      {llvmPointerType /*nbs3*/, llvmPointerType /*nbs2*/,
+       llvmPointerType /*desc*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
+       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
+       llvmInt32Type /*ctp*/, llvmInt32Type /*alg*/,
+       llvmFloat32Type /*chunk_fraction*/, llvmIntPtrType /*bs3*/,
+       llvmPointerType /*buf3*/, llvmIntPtrType /*bs2*/,
+       llvmPointerType /*void *stream*/}};
+  FunctionCallBuilder createSpGEMMComputeBuilder = {
+      "mgpuSpGEMMCompute",
+      llvmIntPtrType,
+      {llvmPointerType /*desc*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
+       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
+       llvmInt32Type /*ctp*/, llvmInt32Type /*alg*/, llvmIntPtrType /*bs*/,
+       llvmPointerType /*buf*/, llvmPointerType /*void *stream*/}};
+  FunctionCallBuilder createSpGEMMCopyBuilder = {
+      "mgpuSpGEMMCopy",
+      llvmVoidType,
+      {llvmPointerType /*desc*/, llvmInt32Type /*ma*/, llvmInt32Type /*mb*/,
+       llvmPointerType /*a*/, llvmPointerType /*b*/, llvmPointerType /*c*/,
+       llvmInt32Type /*ctp*/, llvmInt32Type /*alg*/,
+       llvmPointerType /*void *stream*/}};
+  FunctionCallBuilder createSpGEMMCreateDescrBuilder = {
+      "mgpuSpGEMMCreateDescr",
+      llvmPointerType,
+      {llvmPointerType /*void *stream*/}};
+  FunctionCallBuilder createSpGEMMDestroyDescrBuilder = {
+      "mgpuSpGEMMDestroyDescr",
+      llvmVoidType,
+      {llvmPointerType /*desc*/, llvmPointerType /*void *stream*/}};
+  FunctionCallBuilder createSpGEMMGetSizeBuilder = {
+      "mgpuSpGEMMGetSize",
+      llvmVoidType,
+      {llvmPointerType /*mc*/, llvmPointerType /*rc*/, llvmPointerType /*cc*/,
+       llvmPointerType /*nc*/, llvmPointerType /*void *stream*/}};
 };
 
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -652,6 +697,28 @@ class ConvertSDDMMOpToGpuRuntimeCallPattern
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+// Declares Convert<op_name>ToGpuRuntimeCallPattern for gpu::<op_name>.
+#define DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(op_name)                \
+  class Convert##op_name##ToGpuRuntimeCallPattern                              \
+      : public ConvertOpToGpuRuntimeCallPattern<gpu::op_name> {                \
+  public:                                                                      \
+    Convert##op_name##ToGpuRuntimeCallPattern(                                 \
+        LLVMTypeConverter &typeConverter)                                      \
+        : ConvertOpToGpuRuntimeCallPattern<gpu::op_name>(typeConverter) {}     \
+                                                                               \
+  private:                                                                     \
+    LogicalResult                                                              \
+    matchAndRewrite(gpu::op_name op, OpAdaptor adaptor,                        \
+                    ConversionPatternRewriter &rewriter) const override;       \
+  };
+// TODO: Apply this pattern to all GPU ops.
+DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCreateDescrOp)
+DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMDestroyDescrOp)
+DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMWorkEstimationOrComputeOp)
+DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMEstimateMemoryOp)
+DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMGetSizeOp)
+DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SpGEMMCopyOp)
+#undef DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN
 } // namespace
 
 void GpuToLLVMConversionPass::runOnOperation() {
@@ -1337,6 +1404,14 @@ static Value genConstInt32From(OpBuilder &builder, Location loc, T TValue) {
                                           static_cast<int32_t>(TValue));
 }
 
+template <typename T>
+static Value genConstFloat32From(OpBuilder &builder, Location loc, T TValue) {
+  // Builds an f32 LLVM constant holding `TValue` converted to float; the
+  // f32 counterpart of genConstInt32From.
+  auto f32Value = builder.getF32FloatAttr(static_cast<float>(TValue));
+  return builder.create<LLVM::ConstantOp>(loc, builder.getF32Type(), f32Value);
+}
+
 LogicalResult ConvertCreateDnTensorOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::CreateDnTensorOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -1772,6 +1847,187 @@ LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
+LogicalResult
+ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::SpGEMMCreateDescrOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Bail out unless operands are LLVM-typed and there is one async dependency.
+  if (!succeeded(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
+      !succeeded(isAsyncWithOneDependency(rewriter, op)))
+    return failure();
+  Value stream = adaptor.getAsyncDependencies().front();
+  auto descrCall = createSpGEMMCreateDescrBuilder.create(op.getLoc(), rewriter,
+                                                         {stream});
+  rewriter.replaceOp(op, {descrCall.getResult(), stream});
+  return success();
+}
+
+LogicalResult
+ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::SpGEMMDestroyDescrOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
+      failed(isAsyncWithOneDependency(rewriter, op)))
+    return failure();
+  auto stream = adaptor.getAsyncDependencies().front();
+  // Release the descriptor; the op's only result is the async token.
+  createSpGEMMDestroyDescrBuilder.create(op.getLoc(), rewriter,
+                                         {adaptor.getDesc(), stream});
+  rewriter.replaceOp(op, {stream});
+  return success();
+}
+
+LogicalResult
+ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::SpGEMMWorkEstimationOrComputeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
+      failed(isAsyncWithOneDependency(rewriter, op)))
+    return failure();
+  Location loc = op.getLoc();
+
+  // Enum-valued attributes are passed to the runtime as i32 constants.
+  Value computeType = genConstInt32From(
+      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
+  Value modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
+  Value modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
+  Value alg = genConstInt32From(rewriter, loc, adaptor.getAlg());
+  Value stream = adaptor.getAsyncDependencies().front();
+
+  // The runtime receives the allocated pointer of the buffer memref.
+  Value pBuf =
+      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
+  if (!getTypeConverter()->useOpaquePointers())
+    pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
+
+  // Both runtime entry points take the same operands; `kind` picks one.
+  const bool isWorkEstimation =
+      adaptor.getKind() ==
+      gpu::SpGEMMWorkEstimationOrComputeKind::WORK_ESTIMATION;
+  const FunctionCallBuilder &callBuilder =
+      isWorkEstimation ? createSpGEMMWorkEstimationBuilder
+                       : createSpGEMMComputeBuilder;
+  Value operands[] = {adaptor.getDesc(),
+                      modeA,
+                      modeB,
+                      adaptor.getSpmatA(),
+                      adaptor.getSpmatB(),
+                      adaptor.getSpmatC(), computeType, alg,
+                      adaptor.getBufferSz(), pBuf, stream};
+
+  // The call's result replaces the op's `bufferSzNew` result.
+  Value bufferSizeNew =
+      callBuilder.create(loc, rewriter, operands).getResult();
+  rewriter.replaceOp(op, {bufferSizeNew, stream});
+  return success();
+}
+
+LogicalResult
+ConvertSpGEMMEstimateMemoryOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::SpGEMMEstimateMemoryOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
+      failed(isAsyncWithOneDependency(rewriter, op)))
+    return failure();
+  Location loc = op.getLoc();
+  auto computeType = genConstInt32From(
+      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
+  auto alg = genConstInt32From(rewriter, loc, adaptor.getAlg());
+  auto modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
+  auto modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
+  auto stream = adaptor.getAsyncDependencies().front();
+  // TODO: support chunk fractions other than 1.0.
+  Value chunkFraction = genConstFloat32From(rewriter, loc, 1.0);
+  Value pBuf3 =
+      MemRefDescriptor(adaptor.getBuffer3()).allocatedPtr(rewriter, loc);
+  if (!getTypeConverter()->useOpaquePointers())
+    pBuf3 = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf3);
+
+  auto two = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+                                               rewriter.getIndexAttr(2));
+  auto bufferSize = rewriter.create<LLVM::AllocaOp>(
+      loc, llvmInt64PointerType, llvmInt64Type, two, /*alignment=*/16);
+  // i64-typed GEPs into the alloca: slot 0 -> bufferSz2, slot 1 -> bufferSz3.
+  auto bufferSizePtr2 = rewriter.create<LLVM::GEPOp>(
+      loc, llvmInt64PointerType, llvmInt64Type, bufferSize,
+      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+                                                   rewriter.getIndexAttr(0))});
+  auto bufferSizePtr3 = rewriter.create<LLVM::GEPOp>(
+      loc, llvmInt64PointerType, llvmInt64Type, bufferSize,
+      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+                                                   rewriter.getIndexAttr(1))});
+  // The size out-pointers go first; the new sizes are loaded after the call.
+  createSpGEMMEstimateMemoryBuilder.create(
+      loc, rewriter,
+      {bufferSizePtr3, bufferSizePtr2, adaptor.getDesc(), modeA, modeB,
+       adaptor.getSpmatA(), adaptor.getSpmatB(), adaptor.getSpmatC(),
+       computeType, alg, chunkFraction, adaptor.getBufferSz3(), pBuf3,
+       adaptor.getBufferSz2(), stream});
+  auto bufferSize2 =
+      rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr2);
+  auto bufferSize3 =
+      rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, bufferSizePtr3);
+
+  rewriter.replaceOp(op, {bufferSize3, bufferSize2, stream});
+  return success();
+}
+
+LogicalResult ConvertSpGEMMCopyOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::SpGEMMCopyOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  if (!succeeded(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
+      !succeeded(isAsyncWithOneDependency(rewriter, op)))
+    return failure();
+  Location loc = op.getLoc();
+  Value computeType = genConstInt32From(
+      rewriter, loc, getCuSparseDataTypeFrom(adaptor.getComputeType()));
+  Value modeA = genConstInt32From(rewriter, loc, adaptor.getModeA());
+  Value modeB = genConstInt32From(rewriter, loc, adaptor.getModeB());
+  Value alg = genConstInt32From(rewriter, loc, adaptor.getAlg());
+  Value stream = adaptor.getAsyncDependencies().front(); // also the token
+  Value callArgs[] = {adaptor.getDesc(),   modeA,       modeB,
+                      adaptor.getSpmatA(), adaptor.getSpmatB(),
+                      adaptor.getSpmatC(), computeType, alg, stream};
+  createSpGEMMCopyBuilder.create(loc, rewriter, callArgs);
+  rewriter.replaceOp(op, {stream});
+  return success();
+}
+
+LogicalResult ConvertSpGEMMGetSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::SpGEMMGetSizeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
+      failed(isAsyncWithOneDependency(rewriter, op)))
+    return failure();
+  Location loc = op.getLoc();
+  auto stream = adaptor.getAsyncDependencies().front();
+  // Stack slots for the three i64 sizes (rows, cols, nnz) the runtime fills.
+  auto three = rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+                                                 rewriter.getIndexAttr(3));
+  auto buffer = rewriter.create<LLVM::AllocaOp>(
+      loc, llvmInt64PointerType, llvmInt64Type, three, /*alignment=*/16);
+  // Address the slots with an i64 element type, matching the alloca above.
+  auto rowsPtr = rewriter.create<LLVM::GEPOp>(
+      loc, llvmInt64PointerType, llvmInt64Type, buffer,
+      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+                                                   rewriter.getIndexAttr(0))});
+  auto colsPtr = rewriter.create<LLVM::GEPOp>(
+      loc, llvmInt64PointerType, llvmInt64Type, buffer,
+      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+                                                   rewriter.getIndexAttr(1))});
+  auto nnzsPtr = rewriter.create<LLVM::GEPOp>(
+      loc, llvmInt64PointerType, llvmInt64Type, buffer,
+      ValueRange{rewriter.create<LLVM::ConstantOp>(loc, getIndexType(),
+                                                   rewriter.getIndexAttr(2))});
+  createSpGEMMGetSizeBuilder.create(
+      loc, rewriter, {adaptor.getSpmat(), rowsPtr, colsPtr, nnzsPtr, stream});
+  auto rows = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, rowsPtr);
+  auto cols = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, colsPtr);
+  auto nnzs = rewriter.create<LLVM::LoadOp>(loc, llvmInt64Type, nnzsPtr);
+
+  rewriter.replaceOp(op, {rows, cols, nnzs, stream});
+  return success();
+}
+
 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns,
                                                StringRef gpuBinaryAnnotation,
@@ -1779,6 +2035,7 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
   addOpaquePointerConversion<gpu::AsyncTokenType>(converter);
   addOpaquePointerConversion<gpu::SparseDnTensorHandleType>(converter);
   addOpaquePointerConversion<gpu::SparseSpMatHandleType>(converter);
+  addOpaquePointerConversion<gpu::SparseSpGEMMOpHandleType>(converter);
 
   patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
                ConvertDeallocOpToGpuRuntimeCallPattern,
@@ -1797,6 +2054,12 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertCreateCsrOpToGpuRuntimeCallPattern,
                ConvertCreate2To4SpMatOpToGpuRuntimeCallPattern,
                ConvertDestroySpMatOpToGpuRuntimeCallPattern,
+               ConvertSpGEMMCreateDescrOpToGpuRuntimeCallPattern,
+               ConvertSpGEMMDestroyDescrOpToGpuRuntimeCallPattern,
+               ConvertSpGEMMWorkEstimationOrComputeOpToGpuRuntimeCallPattern,
+               ConvertSpGEMMEstimateMemoryOpToGpuRuntimeCallPattern,
+               ConvertSpGEMMGetSizeOpToGpuRuntimeCallPattern,
+               ConvertSpGEMMCopyOpToGpuRuntimeCallPattern,
                ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
                ConvertSpMVOpToGpuRuntimeCallPattern,
                ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,

diff  --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index c9f378c181e36d..77a2e01b5e0758 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -197,6 +197,7 @@ void GPUDialect::initialize() {
   addTypes<MMAMatrixType>();
   addTypes<SparseDnTensorHandleType>();
   addTypes<SparseSpMatHandleType>();
+  addTypes<SparseSpGEMMOpHandleType>();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/GPU/IR/GPUOps.cpp.inc"
@@ -214,6 +215,8 @@ static std::string getSparseHandleKeyword(SparseHandleKind kind) {
     return "sparse.dntensor_handle";
   case SparseHandleKind::SpMat:
     return "sparse.spmat_handle";
+  case SparseHandleKind::SpGEMMOp:
+    return "sparse.spgemmop_handle";
   }
   llvm_unreachable("unknown sparse handle kind");
   return "";
@@ -266,6 +269,8 @@ Type GPUDialect::parseType(DialectAsmParser &parser) const {
     return SparseDnTensorHandleType::get(context);
   if (keyword == getSparseHandleKeyword(SparseHandleKind::SpMat))
     return SparseSpMatHandleType::get(context);
+  if (keyword == getSparseHandleKeyword(SparseHandleKind::SpGEMMOp))
+    return SparseSpGEMMOpHandleType::get(context);
 
   parser.emitError(parser.getNameLoc(), "unknown gpu type: " + keyword);
   return Type();
@@ -280,6 +285,9 @@ void GPUDialect::printType(Type type, DialectAsmPrinter &os) const {
       })
       .Case<SparseSpMatHandleType>(
           [&](Type) { os << getSparseHandleKeyword(SparseHandleKind::SpMat); })
+      .Case<SparseSpGEMMOpHandleType>([&](Type) {
+        os << getSparseHandleKeyword(SparseHandleKind::SpGEMMOp);
+      })
       .Case<MMAMatrixType>([&](MMAMatrixType fragTy) {
         os << "mma_matrix<";
         auto shape = fragTy.getShape();

diff  --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 664293c53570d2..aece4c55454aa1 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -386,6 +386,7 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
 ///
 
 // Some macro magic to get float/double alpha and beta on host.
+// TODO: add support to passing alpha and beta as arguments
 #define ALPHABETA(dtp, alpha, beta)                                            \
   __nv_bfloat16(alpha##16bf) = 1.0f;                                           \
   __nv_bfloat16(beta##16bf) = 1.0f;                                            \
@@ -567,7 +568,6 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSpMM(int32_t ma, int32_t mb,
                                         CUSPARSE_SPMM_ALG_DEFAULT, buf))
 }
 
-// TODO: add support to passing alpha and beta as arguments
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
 mgpuSDDMMBufferSize(int32_t ma, int32_t mb, void *a, void *b, void *c,
                     int32_t ctp, CUstream /*stream*/) {
@@ -603,6 +603,121 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSDDMM(int32_t ma, int32_t mb,
                                          CUSPARSE_SDDMM_ALG_DEFAULT, buf))
 }
 
+/// Thin wrapper around cusparseSpGEMM_workEstimation.
+///
+/// `s` is an SpGEMM descriptor obtained from mgpuSpGEMMCreateDescr; `a`, `b`
+/// and `c` are cusparseSpMatDescr_t handles for A, B and C; `ma`/`mb` are the
+/// cusparseOperation_t values applied to A/B; `ctp` is the cudaDataType_t
+/// compute type and `alg` the cusparseSpGEMMAlg_t. `bs`/`buf` describe the
+/// caller's work buffer. Callers typically invoke this twice: first with a
+/// zero-sized buffer to learn the required size, then with a buffer of that
+/// size.
+///
+/// Returns the buffer size reported by cuSPARSE, raised to 1 when it is 0 so
+/// callers never issue a zero-byte allocation. The stream argument is ignored.
+/// NOTE(review): the call goes through the shared `cusparse_env` handle, and
+/// alpha/beta are the fixed values produced by ALPHABETA. Confirm both are
+/// what SpGEMM expects here.
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpGEMMWorkEstimation(
+    void *s, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t ctp,
+    int32_t alg, intptr_t bs, void *buf, CUstream /*stream*/) {
+  cusparseSpGEMMDescr_t spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
+  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
+  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
+  cusparseSpGEMMAlg_t algorithm = static_cast<cusparseSpGEMMAlg_t>(alg);
+  cusparseSpMatDescr_t matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
+  cusparseSpMatDescr_t matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
+  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
+  auto cTp = static_cast<cudaDataType_t>(ctp);
+  // Provides the host-side scalars alphap/betap for compute type cTp.
+  ALPHABETA(cTp, alpha, beta)
+  // In/out argument: the caller's buffer size on entry, the size required by
+  // cuSPARSE on return.
+  size_t newBufferSize = bs;
+
+  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_workEstimation(
+      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
+      algorithm, spgemmDesc, &newBufferSize, buf))
+  return newBufferSize == 0 ? 1 : newBufferSize; // avoid zero-alloc
+}
+
+/// Thin wrapper around cusparseSpGEMM_estimateMemory.
+///
+/// `nbs2`/`nbs3` each point to a caller-owned size_t. Before the call they are
+/// seeded with `bs2`/`bs3`; afterwards they hold the sizes reported by
+/// cuSPARSE, each raised to at least 1 so callers never request a zero-byte
+/// allocation. `buf3` is the caller's buffer of `bs3` bytes, and
+/// `chunk_fraction` is forwarded unchanged. The remaining arguments are the
+/// same opaque cuSPARSE handles/enums taken by mgpuSpGEMMWorkEstimation. The
+/// stream argument is ignored.
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuSpGEMMEstimateMemory(void *nbs3, void *nbs2, void *s, int32_t ma,
+                         int32_t mb, void *a, void *b, void *c, int32_t ctp,
+                         int32_t alg, float chunk_fraction, intptr_t bs3,
+                         void *buf3, intptr_t bs2, CUstream /*stream*/) {
+  auto spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
+  auto modeA = static_cast<cusparseOperation_t>(ma);
+  auto modeB = static_cast<cusparseOperation_t>(mb);
+  auto algorithm = static_cast<cusparseSpGEMMAlg_t>(alg);
+  auto matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
+  auto matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
+  auto matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
+  auto cTp = static_cast<cudaDataType_t>(ctp);
+  // Provides the host-side scalars alphap/betap for compute type cTp.
+  ALPHABETA(cTp, alpha, beta)
+
+  // Seed the in/out size slots with the caller's current buffer sizes
+  // (slot 2 first, then slot 3).
+  auto *size2 = static_cast<size_t *>(nbs2);
+  auto *size3 = static_cast<size_t *>(nbs3);
+  *size2 = bs2;
+  *size3 = bs3;
+
+  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_estimateMemory(
+      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
+      algorithm, spgemmDesc, chunk_fraction, size3, buf3, size2))
+
+  // Never hand back a zero size: callers allocate exactly what is reported.
+  if (*size2 == 0)
+    *size2 = 1;
+  if (*size3 == 0)
+    *size3 = 1;
+}
+
+/// Thin wrapper around cusparseSpGEMM_compute.
+///
+/// Takes the same descriptor, matrix handles and enums as
+/// mgpuSpGEMMWorkEstimation, plus the work buffer `buf2` of `bsz2` bytes.
+/// Returns the buffer size reported back by cuSPARSE, raised to 1 when it is 0
+/// so callers never allocate zero bytes. The stream argument is ignored.
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t mgpuSpGEMMCompute(
+    void *s, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t ctp,
+    int32_t alg, intptr_t bsz2, void *buf2, CUstream /*stream*/) {
+  auto spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
+  auto modeA = static_cast<cusparseOperation_t>(ma);
+  auto modeB = static_cast<cusparseOperation_t>(mb);
+  auto algorithm = static_cast<cusparseSpGEMMAlg_t>(alg);
+  auto matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
+  auto matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
+  auto matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
+  auto cTp = static_cast<cudaDataType_t>(ctp);
+  // Provides the host-side scalars alphap/betap for compute type cTp.
+  ALPHABETA(cTp, alpha, beta)
+
+  // In/out size: the caller's buffer size in, cuSPARSE's reported size out.
+  size_t reportedSize = bsz2;
+  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_compute(
+      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
+      algorithm, spgemmDesc, &reportedSize, buf2))
+  if (reportedSize == 0)
+    return 1; // avoid zero-alloc
+  return reportedSize;
+}
+
+/// Thin wrapper around cusparseSpGEMM_copy. Per the cuSPARSE documentation,
+/// this is the last SpGEMM step, which writes the computed product into C.
+/// Arguments mirror mgpuSpGEMMCompute without the work buffer. The stream
+/// argument is ignored.
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuSpGEMMCopy(void *s, int32_t ma, int32_t mb, void *a, void *b, void *c,
+               int32_t ctp, int32_t alg, CUstream /*stream*/) {
+  auto spgemmDesc = reinterpret_cast<cusparseSpGEMMDescr_t>(s);
+  auto modeA = static_cast<cusparseOperation_t>(ma);
+  auto modeB = static_cast<cusparseOperation_t>(mb);
+  auto algorithm = static_cast<cusparseSpGEMMAlg_t>(alg);
+  auto matA = reinterpret_cast<cusparseSpMatDescr_t>(a);
+  auto matB = reinterpret_cast<cusparseSpMatDescr_t>(b);
+  auto matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
+  auto cTp = static_cast<cudaDataType_t>(ctp);
+  // Provides the host-side scalars alphap/betap for compute type cTp.
+  ALPHABETA(cTp, alpha, beta)
+
+  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_copy(
+      cusparse_env, modeA, modeB, alphap, matA, matB, betap, matC, cTp,
+      algorithm, spgemmDesc))
+}
+
+/// Creates a cuSPARSE SpGEMM descriptor and returns it as an opaque pointer,
+/// to be released with mgpuSpGEMMDestroyDescr. The stream argument is ignored.
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *
+mgpuSpGEMMCreateDescr(CUstream /*stream*/) {
+  // cusparseSpGEMMDescr_t is itself a pointer, so it round-trips through void*.
+  cusparseSpGEMMDescr_t descr = nullptr;
+  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_createDescr(&descr))
+  return static_cast<void *>(descr);
+}
+
+/// Destroys an SpGEMM descriptor previously returned by mgpuSpGEMMCreateDescr.
+/// The stream argument is ignored.
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuSpGEMMDestroyDescr(void *s, CUstream /*stream*/) {
+  // The opaque pointer is the cusparseSpGEMMDescr_t handle itself.
+  auto descr = static_cast<cusparseSpGEMMDescr_t>(s);
+  CUSPARSE_REPORT_IF_ERROR(cusparseSpGEMM_destroyDescr(descr))
+}
+
+/// Queries the sparse matrix `m` with cusparseSpMatGetSize. It writes the row
+/// count, column count and nonzero count through `r`, `c` and `n`, each of
+/// which must point to an int64_t. The stream argument is ignored.
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuSpGEMMGetSize(void *m, void *r, void *c, void *n, CUstream /*stream*/) {
+  auto matDescr = static_cast<cusparseConstSpMatDescr_t>(m);
+  auto *rows = static_cast<int64_t *>(r);
+  auto *cols = static_cast<int64_t *>(c);
+  auto *nnz = static_cast<int64_t *>(n);
+  CUSPARSE_REPORT_IF_ERROR(cusparseSpMatGetSize(matDescr, rows, cols, nnz))
+}
+
 #ifdef MLIR_ENABLE_CUDA_CUSPARSELT
 
 ///

diff  --git a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
index fff3e5954d577d..5b0472b79c7635 100644
--- a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
@@ -80,6 +80,83 @@ module attributes {gpu.container_module} {
     return
   }
 
+
+  // CHECK-LABEL:     func @spgemm
+  // CHECK: llvm.call @mgpuStreamCreate
+  // CHECK: llvm.call @mgpuMemAlloc
+  // CHECK: llvm.call @mgpuMemAlloc
+  // CHECK: llvm.call @mgpuCreateCsr
+  // CHECK: llvm.call @mgpuCreateCsr
+  // CHECK: llvm.call @mgpuCreateCsr
+  // CHECK: llvm.call @mgpuSpGEMMCreateDescr
+  // CHECK: llvm.call @malloc
+  // CHECK: llvm.call @mgpuSpGEMMWorkEstimation
+  // CHECK: llvm.call @mgpuMemAlloc
+  // CHECK: llvm.call @mgpuSpGEMMWorkEstimation
+  // CHECK: llvm.call @mgpuSpGEMMEstimateMemory
+  // CHECK: llvm.call @mgpuMemAlloc
+  // CHECK: llvm.call @mgpuSpGEMMEstimateMemory
+  // CHECK: llvm.call @mgpuMemAlloc
+  // CHECK: llvm.call @mgpuSpGEMMCompute
+  // CHECK: llvm.call @mgpuMemAlloc
+  // CHECK: llvm.call @mgpuMemAlloc
+  // CHECK: llvm.call @mgpuStreamSynchronize
+  // CHECK: llvm.call @mgpuStreamDestroy
+  // CHECK: llvm.call @mgpuStreamCreate
+  // CHECK: llvm.call @mgpuSpGEMMCopy
+  // CHECK: llvm.call @mgpuDestroySpMat
+  // CHECK: llvm.call @mgpuDestroySpMat
+  // CHECK: llvm.call @mgpuDestroySpMat
+  // CHECK: llvm.call @mgpuStreamSynchronize
+  // CHECK: llvm.call @mgpuStreamDestroy
+  func.func @spgemm(%arg0: index) {
+    %token0 = gpu.wait async
+    %mem1, %token1 = gpu.alloc async [%token0] (%arg0) : memref<?xindex>
+    %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf64>
+    // A, B and C all reuse the same placeholder buffers: this test only checks
+    // which runtime calls the SpGEMM ops lower to.
+    // NOTE(review): the matrices hold f64 values, but the SpGEMM ops below use
+    // an f32 compute type. That is harmless for a lowering test; confirm it
+    // before reusing this sequence as a runnable example.
+    %spmatA, %token3 = gpu.create_csr async [%token2] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+    %spmatB, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+    %spmatC, %token5 = gpu.create_csr async [%token4] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+    %spgemmDesc, %token6 = gpu.spgemm_create_descr async [%token5]
+    // Zero-sized buffer, paired with size %c0, for the size-query calls below
+    // where no real work buffer exists yet.
+    %alloc = memref.alloc() : memref<0xi8>
+    %c0 = arith.constant 0 : index
+    // Work estimation: query the size of buffer 1, allocate it, then run again
+    // with it.
+    %bufferSz1, %token7 = gpu.spgemm_work_estimation_or_compute async
+                            [%token6]{WORK_ESTIMATION}
+                            %spmatA{NON_TRANSPOSE}, %spmatB{NON_TRANSPOSE},
+                            %spmatC, ALG2, %spgemmDesc, %c0,
+                            %alloc: f32 into memref<0xi8>
+    %buf1, %token8 = gpu.alloc async [%token7] (%bufferSz1) : memref<?xi8>
+    %bufferSz1_1, %token9 = gpu.spgemm_work_estimation_or_compute async
+                              [%token8]{WORK_ESTIMATION} %spmatA, %spmatB,
+                              %spmatC, ALG2, %spgemmDesc, %bufferSz1,
+                              %buf1: f32 into memref<?xi8>
+    // Memory estimation: the first call yields the size of buffer 3; the
+    // second call, given buffer 3, yields the size of buffer 2.
+    %bufferSz3, %dummy, %token10 = gpu.spgemm_estimate_memory async [%token9]
+                                     %spmatA, %spmatB, %spmatC, ALG2,
+                                     %spgemmDesc, %c0, %c0,
+                                     %alloc: f32 into memref<0xi8>
+    %buf3, %token11 = gpu.alloc async [%token10] (%bufferSz3) : memref<?xi8>
+    %bufferSz3_2, %bufferSz2, %token12 = gpu.spgemm_estimate_memory async
+                                          [%token11] %spmatA, %spmatB, %spmatC,
+                                          ALG2, %spgemmDesc, %bufferSz3, %c0,
+                                          %buf3: f32 into memref<?xi8>
+    %buf2, %token13 = gpu.alloc async [%token12] (%bufferSz2) : memref<?xi8>
+    // Compute phase, using buffer 2.
+    %bufferSz2_2, %token14 = gpu.spgemm_work_estimation_or_compute async
+                               [%token13]{COMPUTE} %spmatA, %spmatB, %spmatC,
+                               ALG2, %spgemmDesc, %bufferSz2,
+                               %buf2: f32 into memref<?xi8>
+    // Allocate C's column indices and values. CSR stores one column index per
+    // nonzero, so both buffers are sized by %nnz, not by the column count.
+    %rows, %cols, %nnz, %token15 = gpu.spgemm_get_size async [%token14] %spmatC
+    %mem_columns, %token16 = gpu.alloc async [%token15] (%nnz) : memref<?xi32>
+    %mem_values, %token17 = gpu.alloc async [%token16] (%nnz) : memref<?xf32>
+    // Synchronize, then copy the result and release the matrices on a fresh
+    // stream.
+    gpu.wait [%token17]
+    %token18 = gpu.wait async
+    %token19 = gpu.spgemm_copy async [%token18] %spmatA, %spmatB, %spmatC, ALG2, %spgemmDesc: f32
+    // NOTE(review): %spgemmDesc is never destroyed (presumably via
+    // gpu.spgemm_destroy_descr), so the SpGEMM destroy-descriptor lowering is
+    // not exercised by this test.
+    %token20 = gpu.destroy_sp_mat async [%token19] %spmatA
+    %token21 = gpu.destroy_sp_mat async [%token20] %spmatB
+    %token22 = gpu.destroy_sp_mat async [%token21] %spmatC
+    gpu.wait [%token22]
+    return
+  }
+
 }
 
 

diff  --git a/mlir/test/Dialect/GPU/sparse-roundtrip.mlir b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
index 2d07f8ceaf7274..bf669bd3b46c99 100644
--- a/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
+++ b/mlir/test/Dialect/GPU/sparse-roundtrip.mlir
@@ -54,6 +54,79 @@ module attributes {gpu.container_module} {
     return
   }
 
+  // CHECK-LABEL:     func @spgemm
+  // CHECK:      %{{.*}} = gpu.wait async
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xindex>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf64>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.create_csr async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<?xindex>, memref<?xindex>, memref<?xf64>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_create_descr async [%{{.*}}]
+  // CHECK:           %{{.*}} = memref.alloc() : memref<0xi8>
+  // CHECK:           %{{.*}} = arith.constant 0 : index
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}},  ALG2, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi8>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}},  ALG2, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
+  // CHECK:           %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_estimate_memory async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}},  ALG2, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<0xi8>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi8>
+  // CHECK:           %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_estimate_memory async [%{{.*}}] %{{.*}}, %{{.*}}, %{{.*}},  ALG2, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi8>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.spgemm_work_estimation_or_compute async [%{{.*}}]{{{.*}}} %{{.*}}, %{{.*}}, %{{.*}},  ALG2, %{{.*}}, %{{.*}}, %{{.*}} : f32 into memref<?xi8>
+  // CHECK:           %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} = gpu.spgemm_get_size async [%{{.*}}] %{{.*}}
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xi32>
+  // CHECK:           %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xf32>
+  // CHECK:           gpu.wait [%{{.*}}]
+  // CHECK:           gpu.spgemm_copy  %{{.*}}, %{{.*}}, %{{.*}},  ALG2, %{{.*}} : f32
+  // CHECK:           gpu.destroy_sp_mat  %{{.*}}
+  // CHECK:           gpu.destroy_sp_mat  %{{.*}}
+  // CHECK:           gpu.destroy_sp_mat  %{{.*}}
+  // CHECK:           return
+  func.func @spgemm(%arg0: index) {
+    %token0 = gpu.wait async
+    %mem1, %token1 = gpu.alloc async [%token0] (%arg0) : memref<?xindex>
+    %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf64>
+    // A, B and C all reuse the same placeholder buffers: this test only checks
+    // that the SpGEMM ops parse and print back.
+    // NOTE(review): the matrices hold f64 values, but the SpGEMM ops below use
+    // an f32 compute type. That is harmless for a roundtrip test; confirm it
+    // before reusing this sequence as a runnable example.
+    %spmatA, %token3 = gpu.create_csr async [%token2] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+    %spmatB, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+    %spmatC, %token5 = gpu.create_csr async [%token4] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+    %spgemmDesc, %token6 = gpu.spgemm_create_descr async [%token5]
+    // Zero-sized buffer, paired with size %c0, for the size-query calls below
+    // where no real work buffer exists yet.
+    %alloc = memref.alloc() : memref<0xi8>
+    %c0 = arith.constant 0 : index
+    // Work estimation: query the size of buffer 1, allocate it, then run again
+    // with it.
+    %bufferSz1, %token7 = gpu.spgemm_work_estimation_or_compute async
+                            [%token6]{WORK_ESTIMATION}
+                            %spmatA{NON_TRANSPOSE}, %spmatB{NON_TRANSPOSE},
+                            %spmatC, ALG2, %spgemmDesc, %c0,
+                            %alloc: f32 into memref<0xi8>
+    %buf1, %token8 = gpu.alloc async [%token7] (%bufferSz1) : memref<?xi8>
+    %bufferSz1_1, %token9 = gpu.spgemm_work_estimation_or_compute async
+                              [%token8]{WORK_ESTIMATION} %spmatA, %spmatB,
+                              %spmatC, ALG2, %spgemmDesc, %bufferSz1,
+                              %buf1: f32 into memref<?xi8>
+    // Memory estimation: the first call yields the size of buffer 3; the
+    // second call, given buffer 3, yields the size of buffer 2.
+    %bufferSz3, %dummy, %token10 = gpu.spgemm_estimate_memory async [%token9]
+                                     %spmatA, %spmatB, %spmatC, ALG2,
+                                     %spgemmDesc, %c0, %c0,
+                                     %alloc: f32 into memref<0xi8>
+    %buf3, %token11 = gpu.alloc async [%token10] (%bufferSz3) : memref<?xi8>
+    %bufferSz3_2, %bufferSz2, %token12 = gpu.spgemm_estimate_memory async
+                                          [%token11] %spmatA, %spmatB, %spmatC,
+                                          ALG2, %spgemmDesc, %bufferSz3, %c0,
+                                          %buf3: f32 into memref<?xi8>
+    %buf2, %token13 = gpu.alloc async [%token12] (%bufferSz2) : memref<?xi8>
+    // Compute phase, using buffer 2.
+    %bufferSz2_2, %token14 = gpu.spgemm_work_estimation_or_compute async
+                               [%token13]{COMPUTE} %spmatA, %spmatB, %spmatC,
+                               ALG2, %spgemmDesc, %bufferSz2,
+                               %buf2: f32 into memref<?xi8>
+    // Allocate C's column indices and values. CSR stores one column index per
+    // nonzero, so both buffers are sized by %nnz, not by the column count.
+    %rows, %cols, %nnz, %token15 = gpu.spgemm_get_size async [%token14] %spmatC
+    %mem_columns, %token16 = gpu.alloc async [%token15] (%nnz) : memref<?xi32>
+    %mem_values, %token17 = gpu.alloc async [%token16] (%nnz) : memref<?xf32>
+    gpu.wait [%token17]
+    // Synchronous forms: copy the result, then release the matrices.
+    gpu.spgemm_copy %spmatA, %spmatB, %spmatC, ALG2, %spgemmDesc: f32
+    // NOTE(review): %spgemmDesc is never destroyed (presumably via
+    // gpu.spgemm_destroy_descr), so that op's syntax is not round-tripped here.
+    gpu.destroy_sp_mat %spmatA
+    gpu.destroy_sp_mat %spmatB
+    gpu.destroy_sp_mat %spmatC
+    return
+  }
+
   // CHECK-LABEL: func @sddmm
   // CHECK: %{{.*}} = gpu.wait async
   // CHECK: %{{.*}}, %{{.*}} = gpu.alloc async [%{{.*}}] (%{{.*}}) : memref<?xindex>


        


More information about the Mlir-commits mailing list