[Mlir-commits] [mlir] cf44847 - [mlir][gpu][sparse] adding cusparse sddmm support

Kun Wu llvmlistbot at llvm.org
Sat May 27 13:01:50 PDT 2023


Author: Kun Wu
Date: 2023-05-27T20:01:41Z
New Revision: cf44847b4d1edb43de7ee917ddccf7fa397c63cb

URL: https://github.com/llvm/llvm-project/commit/cf44847b4d1edb43de7ee917ddccf7fa397c63cb
DIFF: https://github.com/llvm/llvm-project/commit/cf44847b4d1edb43de7ee917ddccf7fa397c63cb.diff

LOG: [mlir][gpu][sparse] adding cusparse sddmm support

Differential Revision: https://reviews.llvm.org/D151279

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
    mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
    mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
    mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
    mlir/test/Dialect/GPU/ops.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
index 6b8ede2071af4..a401feea3d075 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -2047,4 +2047,109 @@ def GPU_SpMMOp : GPU_Op<"spmm", [GPU_AsyncOpInterface]> {
   }];
 }
 
+def GPU_SDDMMBufferSizeOp : GPU_Op<"sddmm_buffer_size", [GPU_AsyncOpInterface]> {
+  let summary = "Precompute buffer size for SDDMM operation";
+  let description = [{
+    The `gpu.sddmm_buffer_size` operation returns the buffer size required
+    to perform the SDDMM operation on the given sparse and dense matrices.
+    The operation expects handles returned by previous sparse operations
+    to construct an environment and the operands for SDDMM.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a !gpu.async.token in addition to the buffer size.
+
+    Example:
+
+    ```mlir
+    %buffersz, %token = gpu.sddmm_buffer_size async [%dep] %env, %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC
+    ```
+
+    The dense matrix arguments can also be associated with one of the
+    following operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The
+    default value is NON_TRANSPOSE.
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                   GPU_SparseEnvHandle:$env,
+                   GPU_TransposeModeAttr:$modeA,
+                   GPU_TransposeModeAttr:$modeB,
+                   GPU_SparseDnMatHandle:$dnmatA,
+                   GPU_SparseDnMatHandle:$dnmatB,
+                   GPU_SparseSpMatHandle:$spmatC);
+  let results = (outs Res<Index>:$bufferSz, Optional<GPU_AsyncToken>:$asyncToken);
+
+  let builders = [OpBuilder<(ins
+      "::mlir::Type":$bufferSz,
+      "::mlir::Type":$asyncToken,
+      "::mlir::ValueRange":$asyncDependencies,
+      "::mlir::Value":$env,
+      "::mlir::Value":$dnmatA,
+      "::mlir::Value":$dnmatB,
+      "::mlir::Value":$spmatC), [{
+    auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+    auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
+    return build($_builder, $_state, bufferSz, asyncToken, asyncDependencies,
+                 env, modeA, modeB, dnmatA, dnmatB, spmatC);}]>
+  ];
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC attr-dict
+  }];
+}
+
+def GPU_SDDMMOp : GPU_Op<"sddmm", [GPU_AsyncOpInterface]> {
+  let summary = "SDDMM operation";
+  let description = [{
+    The `gpu.sddmm` operation performs the SDDMM operation on the given sparse
+    and dense matrices, using the given buffer. The operation expects handles
+    returned by previous sparse operations to construct an environment and the
+    operands for SDDMM. The buffer must have been allocated on the device.
+
+    If the `async` keyword is present, the op is executed asynchronously (i.e.
+    it does not block until the execution has finished on the device). In
+    that case, it returns a !gpu.async.token (the op has no other results).
+
+    Example:
+
+    ```mlir
+    %token = gpu.sddmm async [%dep] %env, %dnmatA{TRANSPOSE}, %dnmatB{TRANSPOSE}, %spmatC, %buffer : memref<?xf64>
+    ```
+
+    The dense matrix arguments can also be associated with one of the
+    following operators: NON_TRANSPOSE, TRANSPOSE, CONJUGATE_TRANSPOSE. The
+    default value is NON_TRANSPOSE.
+  }];
+
+  let arguments = (ins Variadic<GPU_AsyncToken>:$asyncDependencies,
+                   GPU_SparseEnvHandle:$env,
+                   GPU_TransposeModeAttr:$modeA,
+                   GPU_TransposeModeAttr:$modeB,
+                   GPU_SparseDnMatHandle:$dnmatA,
+                   GPU_SparseDnMatHandle:$dnmatB,
+                   GPU_SparseSpMatHandle:$spmatC,
+                   AnyMemRef:$buffer);
+  let results = (outs Optional<GPU_AsyncToken>:$asyncToken);
+
+  let builders = [OpBuilder<(ins
+      "::mlir::Type":$asyncToken,
+      "::mlir::ValueRange":$asyncDependencies,
+      "::mlir::Value":$env,
+      "::mlir::Value":$dnmatA,
+      "::mlir::Value":$dnmatB,
+      "::mlir::Value":$spmatC,
+      "::mlir::Value":$buffer), [{
+    auto modeA = gpu::TransposeMode::NON_TRANSPOSE;
+    auto modeB = gpu::TransposeMode::NON_TRANSPOSE;
+    return build($_builder, $_state, asyncToken, asyncDependencies, env, modeA,
+                 modeB, dnmatA, dnmatB, spmatC, buffer);}]>
+  ];
+
+  let assemblyFormat = [{
+    custom<AsyncDependencies>(type($asyncToken), $asyncDependencies)
+    $env `,` $dnmatA (`{` $modeA^ `}`)? `,` $dnmatB (`{` $modeB^ `}`)? `,` $spmatC `,` $buffer attr-dict `:` type($buffer)
+  }];
+}
+
 #endif // GPU_OPS

diff  --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 029c1005e58ca..07ca1e51ed696 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -257,6 +257,18 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
       {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
        llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType,
        llvmPointerType /* void *stream */}};
+  FunctionCallBuilder SDDMMBufferSizeCallBuilder = {
+      "mgpuSDDMMBufferSize", // runtime params: h, ma, mb, a, b, c, dw, stream
+      llvmIntPtrType, // returns the required buffer size
+      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
+       llvmPointerType, llvmPointerType, llvmInt32Type,
+       llvmPointerType /* void *stream */}};
+  FunctionCallBuilder SDDMMCallBuilder = {
+      "mgpuSDDMM", // runtime params: h, ma, mb, a, b, c, dw, buf, stream
+      llvmVoidType,
+      {llvmPointerType, llvmInt32Type, llvmInt32Type, llvmPointerType,
+       llvmPointerType, llvmPointerType, llvmInt32Type, llvmPointerType,
+       llvmPointerType /* void *stream */}};
 };
 
 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
@@ -599,6 +611,20 @@ class ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+class ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::SDDMMBufferSizeOp> {
+public:
+  ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern(
+      LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::SDDMMBufferSizeOp>(
+            typeConverter) {}
+  // matchAndRewrite lowers the op to a mgpuSDDMMBufferSize runtime call.
+private:
+  LogicalResult
+  matchAndRewrite(gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 class ConvertSpMMOpToGpuRuntimeCallPattern
     : public ConvertOpToGpuRuntimeCallPattern<gpu::SpMMOp> {
 public:
@@ -611,6 +637,18 @@ class ConvertSpMMOpToGpuRuntimeCallPattern
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+class ConvertSDDMMOpToGpuRuntimeCallPattern
+    : public ConvertOpToGpuRuntimeCallPattern<gpu::SDDMMOp> {
+public:
+  ConvertSDDMMOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
+      : ConvertOpToGpuRuntimeCallPattern<gpu::SDDMMOp>(typeConverter) {}
+  // matchAndRewrite lowers the op to a mgpuSDDMM runtime call.
+private:
+  LogicalResult
+  matchAndRewrite(gpu::SDDMMOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 void GpuToLLVMConversionPass::runOnOperation() {
@@ -1245,7 +1283,8 @@ LogicalResult ConvertCreateDnVecOpToGpuRuntimeCallPattern::matchAndRewrite(
       MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
   if (!getTypeConverter()->useOpaquePointers())
     pVec = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pVec);
-  Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
+  Type dType =
+      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
   auto handle =
@@ -1281,7 +1320,8 @@ LogicalResult ConvertCreateDnMatOpToGpuRuntimeCallPattern::matchAndRewrite(
       MemRefDescriptor(adaptor.getMemref()).allocatedPtr(rewriter, loc);
   if (!getTypeConverter()->useOpaquePointers())
     pMat = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pMat);
-  Type dType = llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
+  Type dType =
+      llvm::cast<MemRefType>(op.getMemref().getType()).getElementType();
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
   auto handle =
@@ -1325,8 +1365,10 @@ LogicalResult ConvertCreateCooOpToGpuRuntimeCallPattern::matchAndRewrite(
     pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
     pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
   }
-  Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
-  Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
+  Type iType =
+      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
+  Type dType =
+      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
   auto iw = rewriter.create<LLVM::ConstantOp>(
       loc, llvmInt32Type, iType.isIndex() ? 64 : iType.getIntOrFloatBitWidth());
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
@@ -1360,9 +1402,12 @@ LogicalResult ConvertCreateCsrOpToGpuRuntimeCallPattern::matchAndRewrite(
     pColIdxs = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pColIdxs);
     pValues = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pValues);
   }
-  Type pType = llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
-  Type iType = llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
-  Type dType = llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
+  Type pType =
+      llvm::cast<MemRefType>(op.getRowPos().getType()).getElementType();
+  Type iType =
+      llvm::cast<MemRefType>(op.getColIdxs().getType()).getElementType();
+  Type dType =
+      llvm::cast<MemRefType>(op.getValues().getType()).getElementType();
   auto pw = rewriter.create<LLVM::ConstantOp>(
       loc, llvmInt32Type, pType.isIndex() ? 64 : pType.getIntOrFloatBitWidth());
   auto iw = rewriter.create<LLVM::ConstantOp>(
@@ -1445,9 +1490,9 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
-  Type dType = getSpMatElemType(op.getSpmatA());
   auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
   auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
+  Type dType = getSpMatElemType(op.getSpmatA());
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
   auto stream = adaptor.getAsyncDependencies().front();
@@ -1461,6 +1506,29 @@ LogicalResult ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
   return success();
 }
 
+LogicalResult ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::SDDMMBufferSizeOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  bool lowerable =
+      succeeded(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) &&
+      succeeded(isAsyncWithOneDependency(rewriter, op));
+  if (!lowerable)
+    return failure();
+  Location loc = op.getLoc();
+  auto transA = genConstFrom(rewriter, loc, adaptor.getModeA());
+  auto transB = genConstFrom(rewriter, loc, adaptor.getModeB());
+  // The element bit width of sparse matrix C is passed to the runtime as dw.
+  unsigned width = getSpMatElemType(op.getSpmatC()).getIntOrFloatBitWidth();
+  auto dwVal = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, width);
+  Value stream = adaptor.getAsyncDependencies().front();
+  // Argument order follows the mgpuSDDMMBufferSize runtime signature.
+  auto call = SDDMMBufferSizeCallBuilder.create(loc, rewriter,
+      {adaptor.getEnv(), transA, transB, adaptor.getDnmatA(),
+       adaptor.getDnmatB(), adaptor.getSpmatC(), dwVal, stream});
+  rewriter.replaceOp(op, {call.getResult(), stream});
+  return success();
+}
+
 LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
     gpu::SpMMOp op, OpAdaptor adaptor,
     ConversionPatternRewriter &rewriter) const {
@@ -1468,11 +1536,11 @@ LogicalResult ConvertSpMMOpToGpuRuntimeCallPattern::matchAndRewrite(
       failed(isAsyncWithOneDependency(rewriter, op)))
     return failure();
   Location loc = op.getLoc();
+  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
+  auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
   Type dType = getSpMatElemType(op.getSpmatA());
   auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               dType.getIntOrFloatBitWidth());
-  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
-  auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
   auto stream = adaptor.getAsyncDependencies().front();
   Value pBuf =
       MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
@@ -1494,6 +1562,31 @@ static void addOpaquePointerConversion(LLVMTypeConverter &converter) {
   });
 }
 
+LogicalResult ConvertSDDMMOpToGpuRuntimeCallPattern::matchAndRewrite(
+    gpu::SDDMMOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)) ||
+      failed(isAsyncWithOneDependency(rewriter, op)))
+    return failure();
+  Location loc = op.getLoc();
+  // Same operand-preparation order as the SpMM and SDDMM buffer-size patterns.
+  auto modeA = genConstFrom(rewriter, loc, adaptor.getModeA());
+  auto modeB = genConstFrom(rewriter, loc, adaptor.getModeB());
+  unsigned width = getSpMatElemType(op.getSpmatC()).getIntOrFloatBitWidth();
+  auto dw = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, width);
+  auto stream = adaptor.getAsyncDependencies().front();
+  Value pBuf =
+      MemRefDescriptor(adaptor.getBuffer()).allocatedPtr(rewriter, loc);
+  if (!getTypeConverter()->useOpaquePointers())
+    pBuf = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pBuf);
+  SDDMMCallBuilder.create(loc, rewriter,
+                          {adaptor.getEnv(), modeA, modeB, adaptor.getDnmatA(),
+                           adaptor.getDnmatB(), adaptor.getSpmatC(), dw, pBuf,
+                           stream});
+  rewriter.replaceOp(op, {stream});
+  return success();
+}
+
 void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns,
                                                StringRef gpuBinaryAnnotation,
@@ -1526,7 +1619,9 @@ void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                ConvertSpMVBufferSizeOpToGpuRuntimeCallPattern,
                ConvertSpMVOpToGpuRuntimeCallPattern,
                ConvertSpMMBufferSizeOpToGpuRuntimeCallPattern,
-               ConvertSpMMOpToGpuRuntimeCallPattern>(converter);
+               ConvertSpMMOpToGpuRuntimeCallPattern,
+               ConvertSDDMMBufferSizeOpToGpuRuntimeCallPattern,
+               ConvertSDDMMOpToGpuRuntimeCallPattern>(converter);
   patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(
       converter, gpuBinaryAnnotation, kernelBarePtrCallConv);
   patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());

diff  --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index 00a30d134842a..a87834e10f0b4 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -404,3 +404,37 @@ mgpuSpMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c, int32_t dw,
                                         matB, betap, matC, dtp,
                                         CUSPARSE_SPMM_ALG_DEFAULT, buf))
 }
+
+// Returns the buffer size mgpuSDDMM needs for these operands (at least 1).
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT intptr_t
+mgpuSDDMMBufferSize(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
+                    int32_t dw, CUstream /*stream*/) {
+  cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
+  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
+  cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
+  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
+  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
+  ALPHABETA(dw, alpha, beta) // provides alphap/betap, as used by mgpuSpMM
+  size_t bufferSize = 0;
+  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM_bufferSize(
+      handle, modeA, modeB, alphap, matA, matB, betap, matC, dataTp(dw),
+      CUSPARSE_SDDMM_ALG_DEFAULT, &bufferSize))
+  return bufferSize == 0 ? 1 : bufferSize; // avoid zero-alloc
+}
+
+// Runs cusparseSDDMM on operands (a, b, c) using the workspace `buf`.
+extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
+mgpuSDDMM(void *h, int32_t ma, int32_t mb, void *a, void *b, void *c,
+          int32_t dw, void *buf, CUstream /*stream*/) {
+  cusparseHandle_t handle = reinterpret_cast<cusparseHandle_t>(h);
+  cusparseOperation_t modeA = static_cast<cusparseOperation_t>(ma);
+  cusparseOperation_t modeB = static_cast<cusparseOperation_t>(mb);
+  cusparseDnMatDescr_t matA = reinterpret_cast<cusparseDnMatDescr_t>(a);
+  cusparseDnMatDescr_t matB = reinterpret_cast<cusparseDnMatDescr_t>(b);
+  cusparseSpMatDescr_t matC = reinterpret_cast<cusparseSpMatDescr_t>(c);
+  ALPHABETA(dw, alpha, beta) // provides alphap/betap, as used by mgpuSpMM
+  CUSPARSE_REPORT_IF_ERROR(cusparseSDDMM(handle, modeA, modeB, alphap, matA,
+                                         matB, betap, matC, dataTp(dw),
+                                         CUSPARSE_SDDMM_ALG_DEFAULT, buf))
+}

diff  --git a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
index dcef27357b531..678842361b7a3 100644
--- a/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
+++ b/mlir/test/Conversion/GPUCommon/lower-sparse-to-gpu-runtime-calls.mlir
@@ -62,6 +62,36 @@ module attributes {gpu.container_module} {
     return
   }
 
+  // CHECK-LABEL: func @sddmm
+  // CHECK: llvm.call @mgpuStreamCreate
+  // CHECK: llvm.call @mgpuMemAlloc
+  // CHECK: llvm.call @mgpuMemAlloc
+  // CHECK: llvm.call @mgpuCreateSparseEnv
+  // CHECK: llvm.call @mgpuCreateCsr
+  // CHECK: llvm.call @mgpuCreateDnMat
+  // CHECK: llvm.call @mgpuSDDMMBufferSize
+  // CHECK: llvm.call @mgpuSDDMM(
+  // CHECK: llvm.call @mgpuDestroySpMat
+  // CHECK: llvm.call @mgpuDestroyDnMat
+  // CHECK: llvm.call @mgpuDestroySparseEnv
+  // CHECK: llvm.call @mgpuStreamSynchronize
+  // CHECK: llvm.call @mgpuStreamDestroy
+  func.func @sddmm(%arg0: index) {
+    %token0 = gpu.wait async
+    %mem1, %token1 = gpu.alloc async [%token0] (%arg0) : memref<?xindex>
+    %mem2, %token2 = gpu.alloc async [%token1] (%arg0) : memref<?xf64>
+    %env, %token3 = gpu.create_sparse_env async [%token2]
+    %spmat, %token4 = gpu.create_csr async [%token3] %arg0, %arg0, %arg0, %mem1, %mem1, %mem2 : memref<?xindex>, memref<?xindex>, memref<?xf64>
+    %dnmat, %token5 = gpu.create_dn_mat async [%token4] %arg0, %arg0, %mem2 : memref<?xf64>
+    %bufferSz, %token6 = gpu.sddmm_buffer_size async [%token5] %env, %dnmat, %dnmat, %spmat
+    %token7 = gpu.sddmm async [%token6] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64>
+    %token8 = gpu.destroy_sp_mat async [%token7] %spmat
+    %token9 = gpu.destroy_dn_mat async [%token8] %dnmat
+    %token10 = gpu.destroy_sparse_env async [%token9] %env
+    gpu.wait [%token10]
+    return
+  }
+
 }
 
 

diff  --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
index d6c1bef340c93..8900c5bfee581 100644
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -344,16 +344,20 @@ module attributes {gpu.container_module} {
     %bufferSz2, %token10 = gpu.spmm_buffer_size async [%token9] %env, %spmat, %dnmat, %dnmat
     // CHECK: gpu.spmm async
     %token11 = gpu.spmm async [%token10] %env, %spmat, %dnmat, %dnmat, %mem2 : memref<?xf64>
+    // CHECK: gpu.sddmm_buffer_size async
+    %bufferSz3, %token12 = gpu.sddmm_buffer_size async [%token11] %env, %dnmat, %dnmat, %spmat
+    // CHECK: gpu.sddmm async
+    %token13 = gpu.sddmm async [%token12] %env, %dnmat, %dnmat, %spmat, %mem2 : memref<?xf64>
     // CHECK: gpu.destroy_dn_mat async
-    %token12 = gpu.destroy_dn_mat async [%token11] %dnmat
+    %token14 = gpu.destroy_dn_mat async [%token13] %dnmat
     // CHECK: gpu.destroy_sp_mat async
-    %token13 = gpu.destroy_sp_mat async [%token12] %spmat
+    %token15 = gpu.destroy_sp_mat async [%token14] %spmat
     // CHECK: gpu.destroy_dn_vec async
-    %token14 = gpu.destroy_dn_vec async [%token13] %dnvec
+    %token16 = gpu.destroy_dn_vec async [%token15] %dnvec
     // CHECK: gpu.destroy_sparse_env async
-    %token15 = gpu.destroy_sparse_env async [%token14] %env
+    %token17 = gpu.destroy_sparse_env async [%token16] %env
     // CHECK: gpu.wait
-    gpu.wait [%token15]
+    gpu.wait [%token17]
     return
   }
 }


        


More information about the Mlir-commits mailing list